#!/usr/bin/env python3 """ Forest Metrics Inference Demo - Interactive 3-Column Layout Column 1: RGBI Patch | Column 2: Predictions & GT | Column 3: Comparison Plots """ import numpy as np import torch import torch.nn.functional as F import rasterio import gradio as gr from transformers import SegformerModel from torch import nn from PIL import Image, ImageDraw import matplotlib.pyplot as plt import matplotlib.colors as mcolors import os import gc from scipy.ndimage import zoom # ============================================================ # CONFIG # ============================================================ PATCH_SIZE = 224 DROP_BORDER = 16 STRIDE = 112 OUT_CHANNELS = 3 # Patch inference settings PATCH_SIZE_METERS = 190 # 190m × 190m patch GSD = 0.2 # Ground Sample Distance in meters/pixel PATCH_SIZE_PIXELS = int(PATCH_SIZE_METERS / GSD) # 950 pixels # File paths FIXED_INPUT_TIF = "data/dop20rgbi_33348_5612_2_sn.tif" FIXED_GT_CHM = "data/lsc_33348_5612_2_sn_chm.tif" FIXED_GT_PAI = "data/lsc_33348_5612_2_sn_pai.tif" FIXED_GT_FHD = "data/lsc_33348_5612_2_sn_fhd.tif" MODEL_PATH = "data/student_MiTB2.pt" VIDEO_PATH = "data/overview.mp4" # ============================================================ # DEVICE # ============================================================ if torch.backends.mps.is_available(): DEVICE = "mps" elif torch.cuda.is_available(): DEVICE = "cuda" else: DEVICE = "cpu" print(f"[INFO] Using device: {DEVICE}") # ============================================================ # UTILITY FUNCTIONS # ============================================================ def create_colormap_image(array, cmap_name="viridis", vmin=None, vmax=None): """Convert array to RGB image with colormap.""" if vmin is None: vmin = np.nanmin(array) if vmax is None: vmax = np.nanmax(array) norm = mcolors.Normalize(vmin=vmin, vmax=vmax) cmap = plt.get_cmap(cmap_name) rgba = cmap(norm(array)) rgb = (rgba[:, :, :3] * 255).astype(np.uint8) return Image.fromarray(rgb) def create_comparison_plot(pred, gt, title): """Create scatter plot with MAE, RMSE, R², and Bias in top-right.""" pred_flat = pred.flatten() gt_flat = gt.flatten() valid_mask = ~(np.isnan(pred_flat) | np.isnan(gt_flat)) pred_valid = pred_flat[valid_mask] gt_valid = gt_flat[valid_mask] if len(pred_valid) == 0: fig, ax = plt.subplots(figsize=(6, 6)) ax.text(0.5, 0.5, "No valid data", ha="center", va="center") ax.set_title(title) return fig fig, ax = plt.subplots(figsize=(6, 6)) ax.scatter(gt_valid, pred_valid, alpha=0.3, s=1, c="#047857") min_val = min(np.min(gt_valid), np.min(pred_valid)) max_val = max(np.max(gt_valid), np.max(pred_valid)) ax.plot([min_val, max_val], [min_val, max_val], "r--", linewidth=2, label="1:1 line") # Calculate metrics including Bias mae = np.mean(np.abs(pred_valid - gt_valid)) rmse = np.sqrt(np.mean((pred_valid - gt_valid) ** 2)) r2 = np.corrcoef(pred_valid, gt_valid)[0, 1] ** 2 bias = np.mean(pred_valid - gt_valid) # Position metrics in top-right corner metrics_text = f"MAE: {mae:.3f}\nRMSE: {rmse:.3f}\nR²: {r2:.3f}\nBias: {bias:.3f}" ax.text( 0.95, 0.95, metrics_text, transform=ax.transAxes, verticalalignment="top", horizontalalignment="right", bbox=dict(boxstyle="round", facecolor="white", alpha=0.9), fontsize=10 ) ax.set_xlabel("Ground Truth", fontsize=12) ax.set_ylabel("Prediction", fontsize=12) ax.set_title(title, fontsize=14, fontweight="bold") ax.legend() ax.grid(True, alpha=0.3) plt.tight_layout() return fig # ============================================================ # MODEL DEFINITION # ============================================================ class SegFormerStudent(nn.Module): def __init__(self, out_ch=OUT_CHANNELS): super().__init__() self.backbone = SegformerModel.from_pretrained( "nvidia/mit-b2", use_safetensors=True, trust_remote_code=True ) self.backbone.config.output_hidden_states = True old_conv = self.backbone.encoder.patch_embeddings[0].proj new_conv = nn.Conv2d( 4, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding, bias=old_conv.bias is not None, ) with torch.no_grad(): new_conv.weight[:, :3, :, :] = old_conv.weight.clone() new_conv.weight[:, 3:, :, :] = 0.0 if old_conv.bias is not None: new_conv.bias.copy_(old_conv.bias) self.backbone.encoder.patch_embeddings[0].proj = new_conv self.dims = self.backbone.config.hidden_sizes self.proj0 = nn.Conv2d(self.dims[0], 128, 1) self.proj1 = nn.Conv2d(self.dims[1], 128, 1) self.proj2 = nn.Conv2d(self.dims[2], 128, 1) self.proj3 = nn.Conv2d(self.dims[3], 128, 1) self.adapter = nn.Conv2d(128, 256, 1) self.pred_head = nn.Sequential( nn.Conv2d(128, 64, 3, padding=1), nn.GELU(), nn.Conv2d(64, out_ch, 1) ) def forward(self, x): outputs = self.backbone(pixel_values=x, output_hidden_states=True) hs = outputs.hidden_states[-4:] if len(hs[0].shape) == 3: B, N, C = hs[0].shape H = W = int(N**0.5) f0, f1, f2, f3 = [h.permute(0, 2, 1).reshape(B, -1, H, W) for h in hs] else: f0, f1, f2, f3 = hs fused = ( self.proj0(f0) + F.interpolate(self.proj1(f1), size=f0.shape[-2:], mode="bilinear") + F.interpolate(self.proj2(f2), size=f0.shape[-2:], mode="bilinear") + F.interpolate(self.proj3(f3), size=f0.shape[-2:], mode="bilinear") ) return self.pred_head(fused) # ============================================================ # LOAD MODEL AND DATA AT STARTUP # ============================================================ print("[INFO] Loading model...") try: model = SegFormerStudent().to(DEVICE) state = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False) model.load_state_dict(state) model.eval() print("✓ Model loaded successfully!") except Exception as e: print(f"✗ Error loading model: {e}") model = None # Load input tile print("[INFO] Loading input tile...") try: with rasterio.open(FIXED_INPUT_TIF) as src: input_data_full = src.read().astype("float32") input_profile = src.profile rgb_data = input_data_full[:3].copy() rgb_data = np.transpose(rgb_data, (1, 2, 0)) rgb_data = (rgb_data / rgb_data.max() * 255).astype(np.uint8) input_preview = Image.fromarray(rgb_data) print(f"✓ Input tile loaded: {input_data_full.shape}") INPUT_LOADED = True except Exception as e: print(f"✗ Error loading input: {e}") input_preview = None INPUT_LOADED = False # Load ground truth print("[INFO] Loading ground truth data...") try: with rasterio.open(FIXED_GT_CHM) as src: gt_chm_full = src.read(1).astype("float32") gt_profile = src.profile nodata_chm = src.nodata if nodata_chm is not None: gt_chm_full[gt_chm_full == nodata_chm] = np.nan with rasterio.open(FIXED_GT_PAI) as src: gt_pai_full = src.read(1).astype("float32") nodata_pai = src.nodata if nodata_pai is not None: gt_pai_full[gt_pai_full == nodata_pai] = np.nan with rasterio.open(FIXED_GT_FHD) as src: gt_fhd_full = src.read(1).astype("float32") nodata_fhd = src.nodata if nodata_fhd is not None: gt_fhd_full[gt_fhd_full == nodata_fhd] = np.nan gt_fhd_full[gt_fhd_full < -9000] = np.nan print(f"✓ Ground truth loaded") GT_LOADED = True except Exception as e: print(f"✗ Error loading ground truth: {e}") GT_LOADED = False # ============================================================ # PATCH EXTRACTION AND INFERENCE # ============================================================ def extract_patch(array, center_y, center_x, patch_size): """Extract a patch centered at (center_y, center_x).""" H, W = array.shape[-2:] half_patch = patch_size // 2 y_start = max(0, center_y - half_patch) y_end = min(H, center_y + half_patch) x_start = max(0, center_x - half_patch) x_end = min(W, center_x + half_patch) if array.ndim == 3: patch = array[:, y_start:y_end, x_start:x_end] else: patch = array[y_start:y_end, x_start:x_end] return patch, (y_start, y_end, x_start, x_end) def run_patch_inference(evt: gr.SelectData, progress=gr.Progress()): """Run inference on a patch centered at clicked location.""" if not INPUT_LOADED or not GT_LOADED or model is None: return (None,) * 12 + ("❌ Model or data not loaded",) if DEVICE in ["cuda", "mps"]: if DEVICE == "cuda": torch.cuda.empty_cache() else: torch.mps.empty_cache() progress(0, desc="Extracting patch...") try: click_x, click_y = evt.index display_w, display_h = input_preview.size original_h, original_w = input_data_full.shape[1], input_data_full.shape[2] scale_x = original_w / display_w scale_y = original_h / display_h center_x = int(click_x * scale_x) center_y = int(click_y * scale_y) progress(0.1, desc="Extracting input patch...") input_patch, (y_start, y_end, x_start, x_end) = extract_patch( input_data_full, center_y, center_x, PATCH_SIZE_PIXELS ) input_patch = input_patch / max(input_patch.max(), 1e-6) C, H, W = input_patch.shape if C == 3: input_patch = np.vstack([input_patch, np.zeros((1, H, W), dtype=np.float32)]) # Create RGB preview rgb_patch = input_patch[:3].copy() rgb_patch = np.transpose(rgb_patch, (1, 2, 0)) rgb_patch = (rgb_patch * 255).astype(np.uint8) rgb_patch_img = Image.fromarray(rgb_patch) progress(0.2, desc="Preparing display...") # Add bounding box full_img_with_box = input_preview.copy() draw = ImageDraw.Draw(full_img_with_box) box_x_start = int(x_start / scale_x) box_x_end = int(x_end / scale_x) box_y_start = int(y_start / scale_y) box_y_end = int(y_end / scale_y) draw.rectangle([box_x_start, box_y_start, box_x_end, box_y_end], outline="red", width=15) progress(0.3, desc="Running model inference...") patch_tensor = torch.from_numpy(input_patch).unsqueeze(0).to(DEVICE) with torch.inference_mode(): pred = model(patch_tensor)[0] del patch_tensor if DEVICE in ["cuda", "mps"]: if DEVICE == "cuda": torch.cuda.empty_cache() else: torch.mps.empty_cache() progress(0.5, desc="Processing predictions...") pred = F.interpolate( pred.unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False )[0].cpu().numpy() pred_chm = pred[0] pred_pai = pred[1] pred_fhd = pred[2] progress(0.6, desc="Extracting ground truth...") gt_h, gt_w = gt_chm_full.shape gt_scale_y = gt_h / original_h gt_scale_x = gt_w / original_w gt_center_y = int(center_y * gt_scale_y) gt_center_x = int(center_x * gt_scale_x) gt_patch_size = int(PATCH_SIZE_PIXELS * gt_scale_y) gt_chm_patch, _ = extract_patch(gt_chm_full, gt_center_y, gt_center_x, gt_patch_size) gt_pai_patch, _ = extract_patch(gt_pai_full, gt_center_y, gt_center_x, gt_patch_size) gt_fhd_patch, _ = extract_patch(gt_fhd_full, gt_center_y, gt_center_x, gt_patch_size) progress(0.7, desc="Resampling predictions...") if pred_chm.shape != gt_chm_patch.shape: zoom_factors = ( gt_chm_patch.shape[0] / pred_chm.shape[0], gt_chm_patch.shape[1] / pred_chm.shape[1], ) pred_chm = zoom(pred_chm, zoom_factors, order=1) pred_pai = zoom(pred_pai, zoom_factors, order=1) pred_fhd = zoom(pred_fhd, zoom_factors, order=1) progress(0.8, desc="Creating visualizations...") chm_pred_img = create_colormap_image(pred_chm, "viridis", np.nanmin(gt_chm_full), np.nanmax(gt_chm_full)) pai_pred_img = create_colormap_image(pred_pai, "Greens", np.nanmin(gt_pai_full), np.nanmax(gt_pai_full)) fhd_pred_img = create_colormap_image(pred_fhd, "magma", np.nanmin(gt_fhd_full), np.nanmax(gt_fhd_full)) chm_gt_img = create_colormap_image(gt_chm_patch, "viridis", np.nanmin(gt_chm_full), np.nanmax(gt_chm_full)) pai_gt_img = create_colormap_image(gt_pai_patch, "Greens", np.nanmin(gt_pai_full), np.nanmax(gt_pai_full)) fhd_gt_img = create_colormap_image(gt_fhd_patch, "magma", np.nanmin(gt_fhd_full), np.nanmax(gt_fhd_full)) progress(0.9, desc="Generating comparison plots...") chm_plot = create_comparison_plot(pred_chm, gt_chm_patch, "CHM Comparison") pai_plot = create_comparison_plot(pred_pai, gt_pai_patch, "PAI Comparison") fhd_plot = create_comparison_plot(pred_fhd, gt_fhd_patch, "FHD Comparison") chm_mae = np.nanmean(np.abs(pred_chm - gt_chm_patch)) chm_rmse = np.sqrt(np.nanmean((pred_chm - gt_chm_patch) ** 2)) pai_mae = np.nanmean(np.abs(pred_pai - gt_pai_patch)) pai_rmse = np.sqrt(np.nanmean((pred_pai - gt_pai_patch) ** 2)) fhd_mae = np.nanmean(np.abs(pred_fhd - gt_fhd_patch)) fhd_rmse = np.sqrt(np.nanmean((pred_fhd - gt_fhd_patch) ** 2)) status = f""" ✓ **Inference completed on 190m×190m patch** **Location:** ({center_x}, {center_y}) in input coordinates **Patch Metrics:** - CHM: MAE = {chm_mae:.4f} m, RMSE = {chm_rmse:.4f} m - PAI: MAE = {pai_mae:.4f}, RMSE = {pai_rmse:.4f} - FHD: MAE = {fhd_mae:.4f}, RMSE = {fhd_rmse:.4f} """ progress(1.0, desc="Complete!") plt.close("all") gc.collect() return ( full_img_with_box, rgb_patch_img, chm_pred_img, chm_gt_img, chm_plot, pai_pred_img, pai_gt_img, pai_plot, fhd_pred_img, fhd_gt_img, fhd_plot, status ) except Exception as e: import traceback plt.close("all") gc.collect() error_msg = f"❌ Error during inference:\n```\n{traceback.format_exc()}\n```" return (None,) * 11 + (error_msg,) # ============================================================ # GRADIO INTERFACE - 3 COLUMN LAYOUT # ============================================================ def create_demo(): with gr.Blocks(title="FSKD Inference Demo") as demo: gr.HTML("""
1Helmholtz Centre for Environmental Research -- UFZ, Halle (Saale), Germany
taimur.khan@ufz.de
2Leipzig University, Leipzig, Germany
3Georg-August University of Göttingen, Göttingen, Germany
Interactive Demo for Forest Structure Inference (CHM, PAI, FHD) from RGBI Imagery