#!/usr/bin/env python3
"""
Forest Metrics Inference Demo - Interactive 3-Column Layout
Column 1: RGBI Patch | Column 2: Predictions & GT | Column 3: Comparison Plots

Clicking a location on the fixed input RGBI tile runs the SegFormer student on a
PATCH_SIZE_METERS x PATCH_SIZE_METERS window around it and compares its CHM, PAI
and FHD predictions against the ground-truth rasters for the same tile.
"""

import numpy as np
import torch
import torch.nn.functional as F
import rasterio
import gradio as gr
from transformers import SegformerModel
from torch import nn
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os
import gc
import traceback
from scipy.ndimage import zoom

# ============================================================
# CONFIG
# ============================================================
# NOTE: PATCH_SIZE, DROP_BORDER and STRIDE are not referenced anywhere in this
# script; they are kept for parity with the training/tiling configuration.
PATCH_SIZE = 224
DROP_BORDER = 16
STRIDE = 112
OUT_CHANNELS = 3  # prediction channels, in order: CHM, PAI, FHD

# Patch inference settings
PATCH_SIZE_METERS = 190  # 190m × 190m patch
GSD = 0.2  # Ground Sample Distance in meters/pixel
# round() guards against float error in the division (e.g. 949.999... -> 949).
PATCH_SIZE_PIXELS = int(round(PATCH_SIZE_METERS / GSD))  # 950 pixels

# File paths
FIXED_INPUT_TIF = "data/dop20rgbi_33348_5612_2_sn.tif"
FIXED_GT_CHM = "data/lsc_33348_5612_2_sn_chm.tif"
FIXED_GT_PAI = "data/lsc_33348_5612_2_sn_pai.tif"
FIXED_GT_FHD = "data/lsc_33348_5612_2_sn_fhd.tif"
MODEL_PATH = "data/student_MiTB2.pt"
VIDEO_PATH = "data/overview.mp4"

# ============================================================
# DEVICE
# ============================================================
# Prefer Apple MPS, then CUDA, then fall back to CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"[INFO] Using device: {DEVICE}")


def _empty_device_cache():
    """Release cached accelerator memory on CUDA/MPS; no-op on CPU."""
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    elif DEVICE == "mps":
        torch.mps.empty_cache()


# ============================================================
# UTILITY FUNCTIONS
# ============================================================
def create_colormap_image(array, cmap_name="viridis", vmin=None, vmax=None):
    """Render a 2-D array as an RGB PIL image using a matplotlib colormap.

    Args:
        array: 2-D numeric array. NaNs are mapped to the colormap's "bad" colour.
        cmap_name: Name of a registered matplotlib colormap.
        vmin: Lower colour limit; defaults to the NaN-ignoring minimum of ``array``.
        vmax: Upper colour limit; defaults to the NaN-ignoring maximum of ``array``.

    Returns:
        An RGB ``PIL.Image.Image`` (the colormap's alpha channel is dropped).
    """
    if vmin is None:
        vmin = np.nanmin(array)
    if vmax is None:
        vmax = np.nanmax(array)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = plt.get_cmap(cmap_name)
    rgba = cmap(norm(array))
    rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
    return Image.fromarray(rgb)


def create_comparison_plot(pred, gt, title):
    """Scatter-plot predictions against ground truth with summary metrics.

    Pixels that are NaN in either input are excluded. A box in the top-right
    corner reports MAE, RMSE, R² and Bias, where R² is the *squared Pearson
    correlation* between pred and gt (not the coefficient of determination)
    and Bias is mean(pred - gt).

    Args:
        pred: Predicted values (any shape; flattened).
        gt: Ground-truth values with the same number of elements as ``pred``.
        title: Axes title.

    Returns:
        A matplotlib Figure. If no valid pixel pairs remain, the figure only
        contains a "No valid data" message.
    """
    pred_flat = np.ravel(pred)
    gt_flat = np.ravel(gt)
    valid = ~(np.isnan(pred_flat) | np.isnan(gt_flat))
    pred_valid = pred_flat[valid]
    gt_valid = gt_flat[valid]

    fig, ax = plt.subplots(figsize=(6, 6))
    if pred_valid.size == 0:
        ax.text(0.5, 0.5, "No valid data", ha="center", va="center")
        ax.set_title(title)
        return fig

    ax.scatter(gt_valid, pred_valid, alpha=0.3, s=1, c="#047857")
    lo = min(np.min(gt_valid), np.min(pred_valid))
    hi = max(np.max(gt_valid), np.max(pred_valid))
    ax.plot([lo, hi], [lo, hi], "r--", linewidth=2, label="1:1 line")

    err = pred_valid - gt_valid
    mae = np.mean(np.abs(err))
    rmse = np.sqrt(np.mean(err ** 2))
    bias = np.mean(err)
    # Pearson correlation is undefined for a single point or a constant input
    # (e.g. an all-zero GT patch); report NaN instead of letting corrcoef warn.
    if pred_valid.size > 1 and np.std(pred_valid) > 0 and np.std(gt_valid) > 0:
        r2 = np.corrcoef(pred_valid, gt_valid)[0, 1] ** 2
    else:
        r2 = float("nan")

    # Position metrics in top-right corner
    metrics_text = f"MAE: {mae:.3f}\nRMSE: {rmse:.3f}\nR²: {r2:.3f}\nBias: {bias:.3f}"
    ax.text(
        0.95,
        0.95,
        metrics_text,
        transform=ax.transAxes,
        verticalalignment="top",
        horizontalalignment="right",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.9),
        fontsize=10,
    )
    ax.set_xlabel("Ground Truth", fontsize=12)
    ax.set_ylabel("Prediction", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig


# ============================================================
# MODEL DEFINITION
# ============================================================
class SegFormerStudent(nn.Module):
    """RGBI-only SegFormer (MiT-B2) student producing ``out_ch`` dense maps.

    The pretrained ``nvidia/mit-b2`` encoder's first patch-embedding conv is
    replaced by a 4-input-channel conv. Its first three input channels copy the
    pretrained RGB weights and the fourth is zero-initialised, so the extra
    band initially has no effect. Features from the last four hidden states
    are projected to 128 channels, upsampled to the first feature map's
    resolution, summed, and decoded by a small conv head.

    NOTE: the layer layout must stay exactly as-is: ``load_state_dict`` is
    called with strict key matching on a saved checkpoint.
    """

    def __init__(self, out_ch=OUT_CHANNELS):
        super().__init__()
        self.backbone = SegformerModel.from_pretrained(
            "nvidia/mit-b2", use_safetensors=True, trust_remote_code=True
        )
        self.backbone.config.output_hidden_states = True

        # Widen the stem from 3 (RGB) to 4 (RGB + NIR) input channels.
        old_conv = self.backbone.encoder.patch_embeddings[0].proj
        new_conv = nn.Conv2d(
            4,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=old_conv.bias is not None,
        )
        with torch.no_grad():
            new_conv.weight[:, :3, :, :] = old_conv.weight.clone()
            new_conv.weight[:, 3:, :, :] = 0.0
            if old_conv.bias is not None:
                new_conv.bias.copy_(old_conv.bias)
        self.backbone.encoder.patch_embeddings[0].proj = new_conv

        # 1x1 projections of each encoder stage to a common 128-channel width.
        self.dims = self.backbone.config.hidden_sizes
        self.proj0 = nn.Conv2d(self.dims[0], 128, 1)
        self.proj1 = nn.Conv2d(self.dims[1], 128, 1)
        self.proj2 = nn.Conv2d(self.dims[2], 128, 1)
        self.proj3 = nn.Conv2d(self.dims[3], 128, 1)
        # Not used in forward(); presumably a distillation-time feature adapter.
        # It must exist so the checkpoint's state_dict keys match.
        self.adapter = nn.Conv2d(128, 256, 1)
        self.pred_head = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1), nn.GELU(), nn.Conv2d(64, out_ch, 1)
        )

    def forward(self, x):
        """Return ``(B, out_ch, h, w)`` predictions at the first feature map's resolution.

        Args:
            x: Input batch fed to the backbone as ``pixel_values``; the stem
                expects 4 channels.
        """
        outputs = self.backbone(pixel_values=x, output_hidden_states=True)
        hs = outputs.hidden_states[-4:]
        if len(hs[0].shape) == 3:
            # Token-sequence layout (B, N, C): fold back to a square spatial grid.
            B, N, C = hs[0].shape
            H = W = int(N**0.5)
            f0, f1, f2, f3 = [h.permute(0, 2, 1).reshape(B, -1, H, W) for h in hs]
        else:
            f0, f1, f2, f3 = hs
        fused = (
            self.proj0(f0)
            + F.interpolate(self.proj1(f1), size=f0.shape[-2:], mode="bilinear")
            + F.interpolate(self.proj2(f2), size=f0.shape[-2:], mode="bilinear")
            + F.interpolate(self.proj3(f3), size=f0.shape[-2:], mode="bilinear")
        )
        return self.pred_head(fused)


# ============================================================
# LOAD MODEL AND DATA AT STARTUP
# ============================================================
print("[INFO] Loading model...")
try:
    model = SegFormerStudent().to(DEVICE)
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only ever point MODEL_PATH at a trusted checkpoint.
    state = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    model.load_state_dict(state)
    model.eval()
    print("✓ Model loaded successfully!")
except Exception as e:
    print(f"✗ Error loading model: {e}")
    model = None

# Load input tile
print("[INFO] Loading input tile...")
try:
    with rasterio.open(FIXED_INPUT_TIF) as src:
        input_data_full = src.read().astype("float32")  # (bands, H, W)
        input_profile = src.profile
    # 8-bit RGB preview from the first three bands. The epsilon avoids a
    # divide-by-zero on an all-zero tile (same guard as the per-patch path).
    rgb_data = np.transpose(input_data_full[:3], (1, 2, 0))
    rgb_data = (rgb_data / max(rgb_data.max(), 1e-6) * 255).astype(np.uint8)
    input_preview = Image.fromarray(rgb_data)
    print(f"✓ Input tile loaded: {input_data_full.shape}")
    INPUT_LOADED = True
except Exception as e:
    print(f"✗ Error loading input: {e}")
    input_preview = None
    INPUT_LOADED = False


def _read_gt_band(path, nodata_floor=None):
    """Read band 1 of a ground-truth raster as float32 with missing values set to NaN.

    Args:
        path: Raster file path.
        nodata_floor: If given, values strictly below this threshold are also
            set to NaN (in addition to the raster's declared nodata value).

    Returns:
        ``(array, profile)``: the masked 2-D array and the rasterio profile.
    """
    with rasterio.open(path) as src:
        band = src.read(1).astype("float32")
        profile = src.profile
        nodata = src.nodata
    if nodata is not None:
        band[band == nodata] = np.nan
    if nodata_floor is not None:
        band[band < nodata_floor] = np.nan
    return band, profile


# Load ground truth
print("[INFO] Loading ground truth data...")
try:
    gt_chm_full, gt_profile = _read_gt_band(FIXED_GT_CHM)
    gt_pai_full, _ = _read_gt_band(FIXED_GT_PAI)
    gt_fhd_full, _ = _read_gt_band(FIXED_GT_FHD, nodata_floor=-9000)
    # Tile-wide (vmin, vmax) used to colour every patch on a common scale.
    # Computed once here rather than re-scanning the full rasters per click.
    GT_RANGES = {
        "chm": (np.nanmin(gt_chm_full), np.nanmax(gt_chm_full)),
        "pai": (np.nanmin(gt_pai_full), np.nanmax(gt_pai_full)),
        "fhd": (np.nanmin(gt_fhd_full), np.nanmax(gt_fhd_full)),
    }
    print("✓ Ground truth loaded")
    GT_LOADED = True
except Exception as e:
    print(f"✗ Error loading ground truth: {e}")
    GT_RANGES = {}
    GT_LOADED = False


# ============================================================
# PATCH EXTRACTION AND INFERENCE
# ============================================================
def extract_patch(array, center_y, center_x, patch_size):
    """Extract a window of about ``patch_size`` pixels centred at (center_y, center_x).

    The window spans ``[center - patch_size // 2, center + patch_size // 2)``
    on each axis, clipped to the array bounds. Near the edges the patch is
    therefore smaller (and possibly non-square), and an odd ``patch_size``
    yields ``patch_size - 1`` pixels.

    Args:
        array: 2-D ``(H, W)`` or 3-D ``(C, H, W)`` array.
        center_y: Centre row index.
        center_x: Centre column index.
        patch_size: Nominal window size in pixels.

    Returns:
        ``(patch, (y_start, y_end, x_start, x_end))``. ``patch`` is a view into
        ``array``; the bounds are the clipped slice limits.
    """
    H, W = array.shape[-2:]
    half_patch = patch_size // 2
    y_start = max(0, center_y - half_patch)
    y_end = min(H, center_y + half_patch)
    x_start = max(0, center_x - half_patch)
    x_end = min(W, center_x + half_patch)
    if array.ndim == 3:
        patch = array[:, y_start:y_end, x_start:x_end]
    else:
        patch = array[y_start:y_end, x_start:x_end]
    return patch, (y_start, y_end, x_start, x_end)


# Number of Gradio output components wired to run_patch_inference in create_demo().
N_INFERENCE_OUTPUTS = 12


def _error_outputs(message):
    """Build a full output tuple for a failed inference.

    The unmodified input tile is returned in the first slot so the image stays
    visible and clickable. The result slots are cleared, and ``message`` goes
    to the status panel.
    """
    return (input_preview,) + (None,) * (N_INFERENCE_OUTPUTS - 2) + (message,)


def _nan_mae_rmse(pred, gt):
    """Return ``(MAE, RMSE)`` of ``pred - gt``, ignoring NaN pixels."""
    err = pred - gt
    return np.nanmean(np.abs(err)), np.sqrt(np.nanmean(err ** 2))


def run_patch_inference(evt: gr.SelectData, progress=gr.Progress()):
    """Run inference on a patch centred at the clicked location.

    Args:
        evt: Gradio select event. ``evt.index`` is the clicked ``(x, y)``
            position on the displayed input tile.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A 12-tuple matching the ``outputs`` list wired up in ``create_demo``:
        ``(tile with red box, RGB patch, CHM pred, CHM GT, CHM plot, PAI pred,
        PAI GT, PAI plot, FHD pred, FHD GT, FHD plot, status markdown)``.
        On failure, see ``_error_outputs``.
    """
    if not INPUT_LOADED or not GT_LOADED or model is None:
        return _error_outputs("❌ Model or data not loaded")

    _empty_device_cache()
    progress(0, desc="Extracting patch...")

    try:
        # Map the click from displayed-image coordinates to raster pixels.
        click_x, click_y = evt.index
        display_w, display_h = input_preview.size
        original_h, original_w = input_data_full.shape[1], input_data_full.shape[2]
        scale_x = original_w / display_w
        scale_y = original_h / display_h
        center_x = int(click_x * scale_x)
        center_y = int(click_y * scale_y)

        progress(0.1, desc="Extracting input patch...")
        input_patch, (y_start, y_end, x_start, x_end) = extract_patch(
            input_data_full, center_y, center_x, PATCH_SIZE_PIXELS
        )
        # Scale by the patch maximum; the epsilon guards an all-zero patch.
        input_patch = input_patch / max(input_patch.max(), 1e-6)
        C, H, W = input_patch.shape
        if C == 3:
            # The model stem takes 4 bands: pad a zero 4th band for RGB-only input.
            input_patch = np.vstack([input_patch, np.zeros((1, H, W), dtype=np.float32)])

        # RGB preview of the patch (first three bands).
        rgb_patch = np.transpose(input_patch[:3], (1, 2, 0))
        rgb_patch_img = Image.fromarray((rgb_patch * 255).astype(np.uint8))

        progress(0.2, desc="Preparing display...")
        # Outline the extracted window on a copy of the full tile.
        full_img_with_box = input_preview.copy()
        draw = ImageDraw.Draw(full_img_with_box)
        draw.rectangle(
            [
                int(x_start / scale_x),
                int(y_start / scale_y),
                int(x_end / scale_x),
                int(y_end / scale_y),
            ],
            outline="red",
            width=15,
        )

        progress(0.3, desc="Running model inference...")
        patch_tensor = torch.from_numpy(input_patch).unsqueeze(0).to(DEVICE)
        with torch.inference_mode():
            pred = model(patch_tensor)[0]
        del patch_tensor
        _empty_device_cache()

        progress(0.5, desc="Processing predictions...")
        # Upsample the model output back to the input patch size.
        pred = F.interpolate(
            pred.unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False
        )[0].cpu().numpy()
        pred_chm, pred_pai, pred_fhd = pred[0], pred[1], pred[2]

        progress(0.6, desc="Extracting ground truth...")
        # GT rasters may differ in resolution from the input tile: rescale the
        # window centre and size onto the GT grid.
        gt_h, gt_w = gt_chm_full.shape
        gt_scale_y = gt_h / original_h
        gt_scale_x = gt_w / original_w
        gt_center_y = int(center_y * gt_scale_y)
        gt_center_x = int(center_x * gt_scale_x)
        # NOTE(review): one (y-derived) size is used for both axes, which
        # assumes the GT rescales identically in x and y.
        gt_patch_size = int(PATCH_SIZE_PIXELS * gt_scale_y)
        gt_chm_patch, _ = extract_patch(gt_chm_full, gt_center_y, gt_center_x, gt_patch_size)
        gt_pai_patch, _ = extract_patch(gt_pai_full, gt_center_y, gt_center_x, gt_patch_size)
        gt_fhd_patch, _ = extract_patch(gt_fhd_full, gt_center_y, gt_center_x, gt_patch_size)

        progress(0.7, desc="Resampling predictions...")
        # Put predictions on the GT grid so they can be compared pixel-by-pixel.
        if pred_chm.shape != gt_chm_patch.shape:
            zoom_factors = (
                gt_chm_patch.shape[0] / pred_chm.shape[0],
                gt_chm_patch.shape[1] / pred_chm.shape[1],
            )
            pred_chm = zoom(pred_chm, zoom_factors, order=1)
            pred_pai = zoom(pred_pai, zoom_factors, order=1)
            pred_fhd = zoom(pred_fhd, zoom_factors, order=1)

        progress(0.8, desc="Creating visualizations...")
        # Prediction and GT share the tile-wide GT colour range.
        chm_lo, chm_hi = GT_RANGES["chm"]
        pai_lo, pai_hi = GT_RANGES["pai"]
        fhd_lo, fhd_hi = GT_RANGES["fhd"]
        chm_pred_img = create_colormap_image(pred_chm, "viridis", chm_lo, chm_hi)
        pai_pred_img = create_colormap_image(pred_pai, "Greens", pai_lo, pai_hi)
        fhd_pred_img = create_colormap_image(pred_fhd, "magma", fhd_lo, fhd_hi)
        chm_gt_img = create_colormap_image(gt_chm_patch, "viridis", chm_lo, chm_hi)
        pai_gt_img = create_colormap_image(gt_pai_patch, "Greens", pai_lo, pai_hi)
        fhd_gt_img = create_colormap_image(gt_fhd_patch, "magma", fhd_lo, fhd_hi)

        progress(0.9, desc="Generating comparison plots...")
        chm_plot = create_comparison_plot(pred_chm, gt_chm_patch, "CHM Comparison")
        pai_plot = create_comparison_plot(pred_pai, gt_pai_patch, "PAI Comparison")
        fhd_plot = create_comparison_plot(pred_fhd, gt_fhd_patch, "FHD Comparison")

        chm_mae, chm_rmse = _nan_mae_rmse(pred_chm, gt_chm_patch)
        pai_mae, pai_rmse = _nan_mae_rmse(pred_pai, gt_pai_patch)
        fhd_mae, fhd_rmse = _nan_mae_rmse(pred_fhd, gt_fhd_patch)

        status = (
            "✓ **Inference completed on 190m×190m patch**\n\n"
            f"**Location:** ({center_x}, {center_y}) in input coordinates\n\n"
            "**Patch Metrics:**\n"
            f"- CHM: MAE = {chm_mae:.4f} m, RMSE = {chm_rmse:.4f} m\n"
            f"- PAI: MAE = {pai_mae:.4f}, RMSE = {pai_rmse:.4f}\n"
            f"- FHD: MAE = {fhd_mae:.4f}, RMSE = {fhd_rmse:.4f}\n"
        )

        progress(1.0, desc="Complete!")
        plt.close("all")
        gc.collect()

        return (
            full_img_with_box,
            rgb_patch_img,
            chm_pred_img,
            chm_gt_img,
            chm_plot,
            pai_pred_img,
            pai_gt_img,
            pai_plot,
            fhd_pred_img,
            fhd_gt_img,
            fhd_plot,
            status,
        )

    except Exception:
        plt.close("all")
        gc.collect()
        error_msg = f"❌ Error during inference:\n```\n{traceback.format_exc()}\n```"
        return _error_outputs(error_msg)


# ============================================================
# GRADIO INTERFACE - 3 COLUMN LAYOUT
# ============================================================
def create_demo():
    """Build the Gradio Blocks UI.

    Layout:
        * a header with the paper information (and the overview video if present);
        * the clickable input tile beside a status/about panel;
        * one row per metric (CHM, PAI, FHD) with the RGBI patch, the
          prediction/GT images and a pred-vs-GT scatter plot.

    Clicking the tile calls ``run_patch_inference``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="FSKD Inference Demo") as demo:
        gr.HTML("""

FSKD: Monocular Forest Structure Inference via LiDAR-to-RGBI Knowledge Distillation

Taimur Khan1,2, Hannes Feilhauer2, and Muhammad Jazib Zafar3

1Helmholtz Centre for Environmental Research -- UFZ, Halle (Saale), Germany
taimur.khan@ufz.de
2Leipzig University, Leipzig, Germany
3Georg-August University of Göttingen, Göttingen, Germany

Interactive Demo for Forest Structure Inference (CHM, PAI, FHD) from RGBI Imagery

📄 Paper 💻 Code (upon publication) 🤗 ForestPatch Dataset
Abstract: Very High Resolution (VHR) forest structure data at individual-tree scale is essential for carbon, biodiversity, and ecosystem monitoring. Still, airborne LiDAR remains costly and infrequent despite being the reference for forest structure metrics like Canopy Height Model (CHM), Plant Area Index (PAI), and Foliage Height Diversity (FHD). We propose FSKD: a LiDAR-to-RGB-Infrared (RGBI) knowledge distillation (KD) framework in which a multi-modal teacher fuses RGBI imagery with LiDAR-derived planar metrics and vertical profiles via cross-attention, and an RGBI-only SegFormer student learns to reproduce these outputs. Trained on 384 km2 of forests in Saxony, Germany (20 cm ground sampling distance (GSD)) and evaluated on eight geographically distinct test tiles, the student achieves state-of-the-art (SOTA) zero-shot CHM performance (MedAE 4.17 m, R2=0.51, IoU 0.87), outperforming HRCHM/DAC baselines by 29--46% in MAE (5.81 m vs. 8.14--10.84 m) with stronger correlation coefficients (0.713 vs. 0.166--0.652). Ablations show that multi-modal fusion improves performance by 10--26% over RGBI-only training, and that asymmetric distillation with appropriate model capacity is critical. The method jointly predicts CHM, PAI, and FHD, a multi-metric capability not provided by current monocular CHM estimators, although PAI/FHD transfer remains region-dependent and benefits from local calibration. The framework also remains effective under temporal mismatch (winter LiDAR, summer RGBI), removing strict co-acquisition constraints and enabling scalable 20 cm operational monitoring for workflows such as Digital Twin Germany and national Digital Orthophoto programs.
""")

        if os.path.exists(VIDEO_PATH):
            gr.Video(
                value=VIDEO_PATH,
                label="📹 AI Generated Simplified Paper Overview Video",
                autoplay=False,
            )

        gr.Markdown("---")

        # Top row: Input tile and status
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### 👇🏼 Click anywhere on the input DOP tile to run inference")
                input_tile_display = gr.Image(
                    value=input_preview if INPUT_LOADED else None,
                    label="Input RGBI Tile (RGB preview)",
                    type="pil",
                    interactive=False,
                )
            with gr.Column(scale=1):
                status_display = gr.Markdown(value="Awaiting user input...")
                gr.Markdown("""
                ### 📖 About

                **Model**: FSKD (MiT-B2) with knowledge distillation

                **Patch**: 190m×190m ≈ 950×950 pixels at 0.2m GSD

                **Test Tile**: 3348_5612_2
                """)

        gr.Markdown("---")
        gr.Markdown("## 📊 Results: 3-Column Layout (RGBI Patch | Pred/GT | Plot)")

        def metric_row(heading):
            """Add one results row (patch | pred+GT | plot); return its four components."""
            gr.Markdown(heading)
            with gr.Row():
                patch = gr.Image(label="RGBI Patch", type="pil", interactive=False, height=500)
                with gr.Column():
                    pred = gr.Image(label="Prediction", type="pil", interactive=False, height=240)
                    gt = gr.Image(label="Ground Truth", type="pil", interactive=False, height=240)
                plot = gr.Plot(label="Pred vs GT")
            return patch, pred, gt, plot

        # CHM Row
        rgb_patch_chm, chm_pred_display, chm_gt_display, chm_comparison = metric_row(
            "### 🌳 Canopy Height Model (CHM)"
        )
        gr.Markdown("---")
        # PAI Row
        rgb_patch_pai, pai_pred_display, pai_gt_display, pai_comparison = metric_row(
            "### 🍃 Plant Area Index (PAI)"
        )
        gr.Markdown("---")
        # FHD Row
        rgb_patch_fhd, fhd_pred_display, fhd_gt_display, fhd_comparison = metric_row(
            "### 📏 Foliage Height Diversity (FHD)"
        )

        # Connect click event - returns 12 outputs (N_INFERENCE_OUTPUTS)
        input_tile_display.select(
            fn=run_patch_inference,
            inputs=[],
            outputs=[
                input_tile_display,
                rgb_patch_chm,
                chm_pred_display,
                chm_gt_display,
                chm_comparison,
                pai_pred_display,
                pai_gt_display,
                pai_comparison,
                fhd_pred_display,
                fhd_gt_display,
                fhd_comparison,
                status_display,
            ],
        )

        # Only the CHM row receives the patch from inference; mirror it to the
        # PAI and FHD rows whenever it changes.
        rgb_patch_chm.change(
            fn=lambda x: (x, x),
            inputs=[rgb_patch_chm],
            outputs=[rgb_patch_pai, rgb_patch_fhd],
        )

    return demo


# ============================================================
# LAUNCH
# ============================================================
if __name__ == "__main__":
    print("\n" + "=" * 70)
    print("Checking required files...")
    print("=" * 70)
    for path, name in [
        (MODEL_PATH, "Model checkpoint"),
        (FIXED_INPUT_TIF, "Input RGBI tile"),
        (FIXED_GT_CHM, "Ground truth CHM"),
        (FIXED_GT_PAI, "Ground truth PAI"),
        (FIXED_GT_FHD, "Ground truth FHD"),
        (VIDEO_PATH, "Overview video"),
    ]:
        if os.path.exists(path):
            size_mb = os.path.getsize(path) / (1024**2)
            print(f"✓ {name}: {path} ({size_mb:.1f} MB)")
        else:
            print(f"✗ {name}: {path} (NOT FOUND)")
    print("=" * 70 + "\n")

    custom_theme = gr.themes.Soft(primary_hue="emerald", secondary_hue="teal")
    custom_css = ".gradio-container { max-width: 1900px !important; }"
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=custom_theme,
        css=custom_css,
    )