Taimur Khan commited on
Commit
c702ff2
·
1 Parent(s): 38110d6
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.tif filter=lfs diff=lfs merge=lfs -text
37
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .DS_Store
app.py ADDED
@@ -0,0 +1,773 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Forest Metrics Inference Demo - Interactive Patch-based Version
4
+ Click anywhere on the input tile to run inference on a 45m×45m patch
5
+ """
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import rasterio
10
+ import gradio as gr
11
+ from transformers import SegformerModel
12
+ from torch import nn
13
+ from PIL import Image, ImageDraw
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib.colors as mcolors
16
+ import os
17
+ import tempfile
18
+ import zipfile
19
+ import gc
20
+ from scipy.ndimage import zoom
21
+
22
+ # ============================================================
23
+ # CONFIG
24
+ # ============================================================
25
+ PATCH_SIZE = 224
26
+ DROP_BORDER = 16
27
+ STRIDE = 112
28
+ OUT_CHANNELS = 3
29
+
30
+ # Patch inference settings
31
+ PATCH_SIZE_METERS = 380 # 380m × 380m patch
32
+ GSD = 0.2 # Ground Sample Distance in meters/pixel
33
+ PATCH_SIZE_PIXELS = int(PATCH_SIZE_METERS / GSD) # 1900 pixels
34
+
35
+ # File paths
36
+ FIXED_INPUT_TIF = "data/dop20rgbi_33348_5612_2_sn.tif"
37
+ FIXED_GT_CHM = "data/lsc_33348_5612_2_sn_chm.tif"
38
+ FIXED_GT_PAI = "data/lsc_33348_5612_2_sn_pai.tif"
39
+ FIXED_GT_FHD = "data/lsc_33348_5612_2_sn_fhd.tif"
40
+ MODEL_PATH = "data/student_MiTB2.pt"
41
+ VIDEO_PATH = "data/overview.mp4"
42
+
43
+ # ============================================================
44
+ # DEVICE
45
+ # ============================================================
46
+ if torch.backends.mps.is_available():
47
+ DEVICE = "mps"
48
+ elif torch.cuda.is_available():
49
+ DEVICE = "cuda"
50
+ else:
51
+ DEVICE = "cpu"
52
+
53
+ print(f"[INFO] Using device: {DEVICE}")
54
+
55
+
56
+ # ============================================================
57
+ # UTILITY FUNCTIONS
58
+ # ============================================================
59
+ def make_weight_mask(out_size):
60
+ """Pyramidal weight mask for smooth blending."""
61
+ y = np.linspace(-1, 1, out_size)
62
+ x = np.linspace(-1, 1, out_size)
63
+ yy, xx = np.meshgrid(y, x)
64
+ w = (1 - np.abs(xx)) * (1 - np.abs(yy))
65
+ return w
66
+
67
+
68
+ def create_colormap_image(array, cmap_name="viridis", vmin=None, vmax=None):
69
+ """Convert array to RGB image with colormap."""
70
+ if vmin is None:
71
+ vmin = np.nanmin(array)
72
+ if vmax is None:
73
+ vmax = np.nanmax(array)
74
+
75
+ norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
76
+ cmap = plt.get_cmap(cmap_name)
77
+ rgba = cmap(norm(array))
78
+ rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
79
+
80
+ return Image.fromarray(rgb)
81
+
82
+
83
+ def create_comparison_plot(pred, gt, title, cmap="viridis"):
84
+ """Create a scatter plot comparing predictions vs ground truth."""
85
+ # Flatten arrays and remove NaN values
86
+ pred_flat = pred.flatten()
87
+ gt_flat = gt.flatten()
88
+
89
+ # Create mask for valid values
90
+ valid_mask = ~(np.isnan(pred_flat) | np.isnan(gt_flat))
91
+ pred_valid = pred_flat[valid_mask]
92
+ gt_valid = gt_flat[valid_mask]
93
+
94
+ if len(pred_valid) == 0:
95
+ # Return empty plot if no valid data
96
+ fig, ax = plt.subplots(figsize=(6, 6))
97
+ ax.text(0.5, 0.5, "No valid data", ha="center", va="center")
98
+ ax.set_title(title)
99
+ plt.close(fig) # Close to prevent memory leak
100
+ return fig
101
+
102
+ # Create figure
103
+ fig, ax = plt.subplots(figsize=(6, 6))
104
+
105
+ # Scatter plot
106
+ ax.scatter(gt_valid, pred_valid, alpha=0.3, s=1, c="#047857")
107
+
108
+ # 1:1 line
109
+ min_val = min(np.min(gt_valid), np.min(pred_valid))
110
+ max_val = max(np.max(gt_valid), np.max(pred_valid))
111
+ ax.plot(
112
+ [min_val, max_val], [min_val, max_val], "r--", linewidth=2, label="1:1 line"
113
+ )
114
+
115
+ # Calculate metrics
116
+ mae = np.mean(np.abs(pred_valid - gt_valid))
117
+ rmse = np.sqrt(np.mean((pred_valid - gt_valid) ** 2))
118
+ r2 = np.corrcoef(pred_valid, gt_valid)[0, 1] ** 2
119
+
120
+ # Add metrics text
121
+ metrics_text = f"MAE: {mae:.3f}\nRMSE: {rmse:.3f}\nR²: {r2:.3f}"
122
+ ax.text(
123
+ 0.05,
124
+ 0.95,
125
+ metrics_text,
126
+ transform=ax.transAxes,
127
+ verticalalignment="top",
128
+ bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
129
+ )
130
+
131
+ ax.set_xlabel("Ground Truth", fontsize=12)
132
+ ax.set_ylabel("Prediction", fontsize=12)
133
+ ax.set_title(title, fontsize=14, fontweight="bold")
134
+ ax.legend()
135
+ ax.grid(True, alpha=0.3)
136
+
137
+ plt.tight_layout()
138
+ return fig
139
+
140
+
141
+ # ============================================================
142
+ # MODEL DEFINITION
143
+ # ============================================================
144
+ class SegFormerStudent(nn.Module):
145
+ def __init__(self, out_ch=OUT_CHANNELS):
146
+ super().__init__()
147
+
148
+ self.backbone = SegformerModel.from_pretrained(
149
+ "nvidia/mit-b2", use_safetensors=True, trust_remote_code=True
150
+ )
151
+ self.backbone.config.output_hidden_states = True
152
+
153
+ old_conv = self.backbone.encoder.patch_embeddings[0].proj
154
+ new_conv = nn.Conv2d(
155
+ 4,
156
+ old_conv.out_channels,
157
+ kernel_size=old_conv.kernel_size,
158
+ stride=old_conv.stride,
159
+ padding=old_conv.padding,
160
+ bias=old_conv.bias is not None,
161
+ )
162
+ with torch.no_grad():
163
+ new_conv.weight[:, :3, :, :] = old_conv.weight.clone()
164
+ new_conv.weight[:, 3:, :, :] = 0.0
165
+ if old_conv.bias is not None:
166
+ new_conv.bias.copy_(old_conv.bias)
167
+ self.backbone.encoder.patch_embeddings[0].proj = new_conv
168
+
169
+ self.dims = self.backbone.config.hidden_sizes
170
+ self.proj0 = nn.Conv2d(self.dims[0], 128, 1)
171
+ self.proj1 = nn.Conv2d(self.dims[1], 128, 1)
172
+ self.proj2 = nn.Conv2d(self.dims[2], 128, 1)
173
+ self.proj3 = nn.Conv2d(self.dims[3], 128, 1)
174
+ self.adapter = nn.Conv2d(128, 256, 1)
175
+
176
+ self.pred_head = nn.Sequential(
177
+ nn.Conv2d(128, 64, 3, padding=1), nn.GELU(), nn.Conv2d(64, out_ch, 1)
178
+ )
179
+
180
+ def forward(self, x):
181
+ outputs = self.backbone(pixel_values=x, output_hidden_states=True)
182
+ hs = outputs.hidden_states[-4:]
183
+
184
+ if len(hs[0].shape) == 3:
185
+ B, N, C = hs[0].shape
186
+ H = W = int(N**0.5)
187
+ f0, f1, f2, f3 = [h.permute(0, 2, 1).reshape(B, -1, H, W) for h in hs]
188
+ else:
189
+ f0, f1, f2, f3 = hs
190
+
191
+ fused = (
192
+ self.proj0(f0)
193
+ + F.interpolate(self.proj1(f1), size=f0.shape[-2:], mode="bilinear")
194
+ + F.interpolate(self.proj2(f2), size=f0.shape[-2:], mode="bilinear")
195
+ + F.interpolate(self.proj3(f3), size=f0.shape[-2:], mode="bilinear")
196
+ )
197
+
198
+ return self.pred_head(fused)
199
+
200
+
201
+ # ============================================================
202
+ # LOAD MODEL AND DATA AT STARTUP
203
+ # ============================================================
204
+ print("[INFO] Loading model...")
205
+ try:
206
+ model = SegFormerStudent().to(DEVICE)
207
+ state = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
208
+ model.load_state_dict(state)
209
+ model.eval()
210
+ print("✓ Model loaded successfully!")
211
+ except Exception as e:
212
+ print(f"✗ Error loading model: {e}")
213
+ model = None
214
+
215
+ # Load input tile
216
+ print("[INFO] Loading input tile...")
217
+ try:
218
+ with rasterio.open(FIXED_INPUT_TIF) as src:
219
+ input_data_full = src.read().astype("float32") # C, H, W
220
+ input_profile = src.profile
221
+ input_bounds = src.bounds
222
+ input_transform = src.transform
223
+
224
+ # Create RGB preview for display
225
+ rgb_data = input_data_full[:3].copy()
226
+ rgb_data = np.transpose(rgb_data, (1, 2, 0)) # H, W, C
227
+ rgb_data = (rgb_data / rgb_data.max() * 255).astype(np.uint8)
228
+ input_preview = Image.fromarray(rgb_data)
229
+
230
+ print(f"✓ Input tile loaded: {input_data_full.shape}")
231
+ INPUT_LOADED = True
232
+ except Exception as e:
233
+ print(f"✗ Error loading input: {e}")
234
+ input_preview = None
235
+ INPUT_LOADED = False
236
+
237
+ # Load ground truth
238
+ print("[INFO] Loading ground truth data...")
239
+ try:
240
+ with rasterio.open(FIXED_GT_CHM) as src:
241
+ gt_chm_full = src.read(1).astype("float32")
242
+ gt_profile = src.profile
243
+ nodata_chm = src.nodata
244
+ if nodata_chm is not None:
245
+ gt_chm_full[gt_chm_full == nodata_chm] = np.nan
246
+
247
+ with rasterio.open(FIXED_GT_PAI) as src:
248
+ gt_pai_full = src.read(1).astype("float32")
249
+ nodata_pai = src.nodata
250
+ if nodata_pai is not None:
251
+ gt_pai_full[gt_pai_full == nodata_pai] = np.nan
252
+
253
+ with rasterio.open(FIXED_GT_FHD) as src:
254
+ gt_fhd_full = src.read(1).astype("float32")
255
+ nodata_fhd = src.nodata
256
+ if nodata_fhd is not None:
257
+ gt_fhd_full[gt_fhd_full == nodata_fhd] = np.nan
258
+ gt_fhd_full[gt_fhd_full < -9000] = np.nan
259
+
260
+ print(f"✓ Ground truth loaded")
261
+ print(f" CHM range: [{np.nanmin(gt_chm_full):.2f}, {np.nanmax(gt_chm_full):.2f}]")
262
+ print(f" PAI range: [{np.nanmin(gt_pai_full):.2f}, {np.nanmax(gt_pai_full):.2f}]")
263
+ print(f" FHD range: [{np.nanmin(gt_fhd_full):.2f}, {np.nanmax(gt_fhd_full):.2f}]")
264
+ GT_LOADED = True
265
+ except Exception as e:
266
+ print(f"✗ Error loading ground truth: {e}")
267
+ GT_LOADED = False
268
+
269
+
270
+ # ============================================================
271
+ # PATCH EXTRACTION AND INFERENCE
272
+ # ============================================================
273
+ def extract_patch(array, center_y, center_x, patch_size):
274
+ """Extract a patch centered at (center_y, center_x)."""
275
+ H, W = array.shape[-2:]
276
+
277
+ # Calculate patch boundaries
278
+ half_patch = patch_size // 2
279
+ y_start = max(0, center_y - half_patch)
280
+ y_end = min(H, center_y + half_patch)
281
+ x_start = max(0, center_x - half_patch)
282
+ x_end = min(W, center_x + half_patch)
283
+
284
+ # Extract patch
285
+ if array.ndim == 3: # C, H, W
286
+ patch = array[:, y_start:y_end, x_start:x_end]
287
+ else: # H, W
288
+ patch = array[y_start:y_end, x_start:x_end]
289
+
290
+ return patch, (y_start, y_end, x_start, x_end)
291
+
292
+
293
+ def run_patch_inference(evt: gr.SelectData, progress=gr.Progress()):
294
+ """Run inference on a patch centered at clicked location."""
295
+
296
+ if not INPUT_LOADED or not GT_LOADED or model is None:
297
+ return (
298
+ None,
299
+ None,
300
+ None,
301
+ None,
302
+ None,
303
+ None,
304
+ None,
305
+ None,
306
+ None,
307
+ None,
308
+ None,
309
+ "❌ Model or data not loaded",
310
+ )
311
+
312
+ # Clear any previous GPU memory
313
+ if DEVICE == "cuda":
314
+ torch.cuda.empty_cache()
315
+ elif DEVICE == "mps":
316
+ torch.mps.empty_cache()
317
+
318
+ progress(0, desc="Extracting patch...")
319
+
320
+ try:
321
+ # Get click coordinates from the event
322
+ # evt.index is [x, y] for images in Gradio
323
+ click_x, click_y = evt.index
324
+
325
+ print(f"[INFO] Click at display ({click_x}, {click_y})")
326
+
327
+ # Get display and original image dimensions
328
+ display_w, display_h = input_preview.size
329
+ original_h, original_w = input_data_full.shape[1], input_data_full.shape[2]
330
+
331
+ # Calculate scaling factors
332
+ scale_x = original_w / display_w
333
+ scale_y = original_h / display_h
334
+
335
+ # Convert to original coordinates
336
+ center_x = int(click_x * scale_x)
337
+ center_y = int(click_y * scale_y)
338
+
339
+ print(f"[INFO] Mapped to original ({center_x}, {center_y})")
340
+
341
+ progress(0.1, desc="Extracting input patch...")
342
+
343
+ # Extract input patch (RGBI) - from original high-res input
344
+ input_patch, (y_start, y_end, x_start, x_end) = extract_patch(
345
+ input_data_full, center_y, center_x, PATCH_SIZE_PIXELS
346
+ )
347
+
348
+ # Normalize
349
+ input_patch = input_patch / max(input_patch.max(), 1e-6)
350
+
351
+ # Ensure 4 channels
352
+ C, H, W = input_patch.shape
353
+ if C == 3:
354
+ input_patch = np.vstack(
355
+ [input_patch, np.zeros((1, H, W), dtype=np.float32)]
356
+ )
357
+
358
+ # Create RGB preview of patch
359
+ rgb_patch = input_patch[:3].copy()
360
+ rgb_patch = np.transpose(rgb_patch, (1, 2, 0))
361
+ rgb_patch = (rgb_patch * 255).astype(np.uint8)
362
+ rgb_patch_img = Image.fromarray(rgb_patch)
363
+
364
+ progress(0.2, desc="Preparing display...")
365
+
366
+ # Add boundary box to full image
367
+ full_img_with_box = input_preview.copy()
368
+ draw = ImageDraw.Draw(full_img_with_box)
369
+
370
+ # Scale box coordinates to display size
371
+ box_x_start = int(x_start / scale_x)
372
+ box_x_end = int(x_end / scale_x)
373
+ box_y_start = int(y_start / scale_y)
374
+ box_y_end = int(y_end / scale_y)
375
+
376
+ draw.rectangle(
377
+ [box_x_start, box_y_start, box_x_end, box_y_end], outline="red", width=8
378
+ )
379
+
380
+ progress(0.3, desc="Running model inference...")
381
+
382
+ # Run inference on patch
383
+ patch_tensor = torch.from_numpy(input_patch).unsqueeze(0).to(DEVICE)
384
+
385
+ with torch.inference_mode(): # More efficient than no_grad()
386
+ pred = model(patch_tensor)[0] # C, H_out, W_out
387
+
388
+ # Clear GPU memory after inference
389
+ del patch_tensor
390
+ if DEVICE == "cuda":
391
+ torch.cuda.empty_cache()
392
+ elif DEVICE == "mps":
393
+ torch.mps.empty_cache()
394
+
395
+ progress(0.5, desc="Processing predictions...")
396
+
397
+ # Upsample to original patch size
398
+ pred = (
399
+ F.interpolate(
400
+ pred.unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False
401
+ )[0]
402
+ .cpu()
403
+ .numpy()
404
+ )
405
+
406
+ # Extract predictions
407
+ pred_chm = pred[0]
408
+ pred_pai = pred[1]
409
+ pred_fhd = pred[2]
410
+
411
+ progress(0.6, desc="Extracting ground truth...")
412
+
413
+ # Calculate corresponding GT coordinates
414
+ # GT has different resolution than input, so we need to scale coordinates
415
+ gt_h, gt_w = gt_chm_full.shape
416
+
417
+ # Scale center coordinates to GT resolution
418
+ gt_scale_y = gt_h / original_h
419
+ gt_scale_x = gt_w / original_w
420
+
421
+ gt_center_y = int(center_y * gt_scale_y)
422
+ gt_center_x = int(center_x * gt_scale_x)
423
+
424
+ # Scale patch size to GT resolution
425
+ gt_patch_size = int(PATCH_SIZE_PIXELS * gt_scale_y) # Assuming square pixels
426
+
427
+ # Extract ground truth patches at correct resolution
428
+ gt_chm_patch, _ = extract_patch(
429
+ gt_chm_full, gt_center_y, gt_center_x, gt_patch_size
430
+ )
431
+ gt_pai_patch, _ = extract_patch(
432
+ gt_pai_full, gt_center_y, gt_center_x, gt_patch_size
433
+ )
434
+ gt_fhd_patch, _ = extract_patch(
435
+ gt_fhd_full, gt_center_y, gt_center_x, gt_patch_size
436
+ )
437
+
438
+ progress(0.7, desc="Resampling predictions to match ground truth...")
439
+
440
+ # Resize predictions to match GT resolution
441
+ if pred_chm.shape != gt_chm_patch.shape:
442
+ zoom_factors = (
443
+ gt_chm_patch.shape[0] / pred_chm.shape[0],
444
+ gt_chm_patch.shape[1] / pred_chm.shape[1],
445
+ )
446
+ pred_chm = zoom(pred_chm, zoom_factors, order=1)
447
+ pred_pai = zoom(pred_pai, zoom_factors, order=1)
448
+ pred_fhd = zoom(pred_fhd, zoom_factors, order=1)
449
+ print(
450
+ f"[INFO] Resampled predictions from {pred.shape} to {gt_chm_patch.shape}"
451
+ )
452
+
453
+ progress(0.8, desc="Creating visualizations...")
454
+
455
+ # Create visualization images (patches are now large enough, no upscaling needed)
456
+ chm_pred_img = create_colormap_image(
457
+ pred_chm, "viridis", np.nanmin(gt_chm_full), np.nanmax(gt_chm_full)
458
+ )
459
+ pai_pred_img = create_colormap_image(
460
+ pred_pai, "Greens", np.nanmin(gt_pai_full), np.nanmax(gt_pai_full)
461
+ )
462
+ fhd_pred_img = create_colormap_image(
463
+ pred_fhd, "magma", np.nanmin(gt_fhd_full), np.nanmax(gt_fhd_full)
464
+ )
465
+
466
+ chm_gt_img = create_colormap_image(
467
+ gt_chm_patch, "viridis", np.nanmin(gt_chm_full), np.nanmax(gt_chm_full)
468
+ )
469
+ pai_gt_img = create_colormap_image(
470
+ gt_pai_patch, "Greens", np.nanmin(gt_pai_full), np.nanmax(gt_pai_full)
471
+ )
472
+ fhd_gt_img = create_colormap_image(
473
+ gt_fhd_patch, "magma", np.nanmin(gt_fhd_full), np.nanmax(gt_fhd_full)
474
+ )
475
+
476
+ progress(0.9, desc="Generating comparison plots...")
477
+
478
+ # Create comparison plots
479
+ chm_plot = create_comparison_plot(
480
+ pred_chm, gt_chm_patch, "CHM Comparison", "viridis"
481
+ )
482
+ pai_plot = create_comparison_plot(
483
+ pred_pai, gt_pai_patch, "PAI Comparison", "Greens"
484
+ )
485
+ fhd_plot = create_comparison_plot(
486
+ pred_fhd, gt_fhd_patch, "FHD Comparison", "magma"
487
+ )
488
+
489
+ # Free intermediate arrays to reduce memory footprint
490
+ del pred_chm, pred_pai, pred_fhd
491
+ del gt_chm_patch, gt_pai_patch, gt_fhd_patch
492
+
493
+ # Calculate metrics
494
+ chm_mae = np.nanmean(np.abs(pred_chm - gt_chm_patch))
495
+ chm_rmse = np.sqrt(np.nanmean((pred_chm - gt_chm_patch) ** 2))
496
+ pai_mae = np.nanmean(np.abs(pred_pai - gt_pai_patch))
497
+ pai_rmse = np.sqrt(np.nanmean((pred_pai - gt_pai_patch) ** 2))
498
+ fhd_mae = np.nanmean(np.abs(pred_fhd - gt_fhd_patch))
499
+ fhd_rmse = np.sqrt(np.nanmean((pred_fhd - gt_fhd_patch) ** 2))
500
+
501
+ status = f"""
502
+ ✓ **Inference completed on 45m×45m patch**
503
+
504
+ **Location:** ({center_x}, {center_y}) in input coordinates
505
+ **Patch Metrics:**
506
+ - CHM: MAE = {chm_mae:.4f} m, RMSE = {chm_rmse:.4f} m
507
+ - PAI: MAE = {pai_mae:.4f}, RMSE = {pai_rmse:.4f}
508
+ - FHD: MAE = {fhd_mae:.4f}, RMSE = {fhd_rmse:.4f}
509
+ - Input patch: {H}×{W} pixels ({PATCH_SIZE_METERS}m×{PATCH_SIZE_METERS}m)
510
+ - GT patch: {gt_chm_patch.shape[0]}×{gt_chm_patch.shape[1]} pixels
511
+ """
512
+
513
+ progress(1.0, desc="Complete!")
514
+
515
+ # Close all matplotlib figures to prevent memory leaks
516
+ plt.close('all')
517
+
518
+ # Force garbage collection
519
+ gc.collect()
520
+ if DEVICE == "cuda":
521
+ torch.cuda.empty_cache()
522
+ elif DEVICE == "mps":
523
+ torch.mps.empty_cache()
524
+
525
+ return (
526
+ full_img_with_box,
527
+ rgb_patch_img,
528
+ chm_pred_img,
529
+ pai_pred_img,
530
+ fhd_pred_img,
531
+ chm_gt_img,
532
+ pai_gt_img,
533
+ fhd_gt_img,
534
+ chm_plot,
535
+ pai_plot,
536
+ fhd_plot,
537
+ status,
538
+ )
539
+
540
+ except Exception as e:
541
+ import traceback
542
+
543
+ # Cleanup on error
544
+ plt.close('all')
545
+ gc.collect()
546
+ if DEVICE == "cuda":
547
+ torch.cuda.empty_cache()
548
+ elif DEVICE == "mps":
549
+ torch.mps.empty_cache()
550
+
551
+ error_msg = f"❌ Error during inference:\n```\n{traceback.format_exc()}\n```"
552
+ return (
553
+ None,
554
+ None,
555
+ None,
556
+ None,
557
+ None,
558
+ None,
559
+ None,
560
+ None,
561
+ None,
562
+ None,
563
+ None,
564
+ error_msg,
565
+ )
566
+
567
+
568
+ # ============================================================
569
+ # GRADIO INTERFACE
570
+ # ============================================================
571
+ def create_demo():
572
+
573
+ with gr.Blocks(title="FSKD Inference Demo") as demo:
574
+ # <sup>1</sup>, Co-Author Name<sup>1,2</sup>, Another Author<sup>1</sup>
575
+ gr.HTML(
576
+ """
577
+ <div style="text-align: center; padding: 2.5rem 2rem; background: linear-gradient(135deg, #065f46 0%, #047857 100%); border-radius: 1rem; margin-bottom: 1.5rem; color: white;">
578
+ <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 0.5rem; color: white;">Monocular Forest Structure Inference via LiDAR-to-RGBI Knowledge Distillation</h1>
579
+ <h2 style="font-size: 1.2rem; font-weight: 400; margin: 1rem 0 0.5rem 0; color: white; opacity: 0.95;">
580
+ Taimur Khan<sup style="color: red;">1,2</sup>
581
+ </h2>
582
+ <p style="font-size: 0.85rem; margin: 0.25rem 0 0.75rem 0; color: white; opacity: 0.8;">
583
+ <sup style="color: red;">1</sup>Leipzig University, <br/>
584
+ <sup style="color: red;">2</sup>Helmholtz Centre for Environmental Research -- UFZ
585
+ </p>
586
+ <div style="display: flex; justify-content: center; gap: 1rem; margin-top: 1.5rem; flex-wrap: wrap;">
587
+ <button disabled style="padding: 0.75rem 1.5rem; background: linear-gradient(135deg, #24292e 0%, #1a1e22 100%); color: white; border: none; border-radius: 0.5rem; font-weight: 600; cursor: not-allowed; opacity: 0.7; font-size: 0.95rem;">
588
+ 💻 Code (Coming Soon)
589
+ </button>
590
+ <button disabled style="padding: 0.75rem 1.5rem; background: linear-gradient(135deg, #FF9D00 0%, #FF7A00 100%); color: white; border: none; border-radius: 0.5rem; font-weight: 600; cursor: not-allowed; opacity: 0.7; font-size: 0.95rem;">
591
+ 🤗 Model (Coming Soon)
592
+ </button>
593
+ <button disabled style="padding: 0.75rem 1.5rem; background: linear-gradient(135deg, #FF9D00 0%, #FF7A00 100%); color: white; border: none; border-radius: 0.5rem; font-weight: 600; cursor: not-allowed; opacity: 0.7; font-size: 0.95rem;">
594
+ 🤗 Dataset (Coming Soon)
595
+ </button>
596
+ <button disabled style="padding: 0.75rem 1.5rem; background: linear-gradient(135deg, #B31B1B 0%, #8B0000 100%); color: white; border: none; border-radius: 0.5rem; font-weight: 600; cursor: not-allowed; opacity: 0.7; font-size: 0.95rem;">
597
+ 📄 Paper (Coming Soon)
598
+ </button>
599
+ </div>
600
+ <p style="font-size: 1.1rem; opacity: 0.95; color: white;">Interactive Patch-based Deep Learning Demo for Forest Structure Inference <br/> (CHM, PAI, and FHD) from RGBI Aerial Image with SegFormer-B2</p>
601
+
602
+ </div>
603
+ """
604
+ )
605
+
606
+ # Video overview
607
+ if os.path.exists(VIDEO_PATH):
608
+ gr.Video(
609
+ value=VIDEO_PATH,
610
+ label="📹 Simplified Paper Overview (NotebookLM Generated)",
611
+ autoplay=False,
612
+ show_label=True,
613
+ )
614
+
615
+ gr.Markdown("---")
616
+
617
+ # Main interaction area
618
+ with gr.Row():
619
+ with gr.Column(scale=1):
620
+ gr.Markdown(
621
+ "### 👇🏼 Click anywhere on the input tile to get inferred forest metrics"
622
+ )
623
+ input_tile_display = gr.Image(
624
+ value=input_preview if INPUT_LOADED else None,
625
+ label="Input RGBI Tile (RGB preview)",
626
+ type="pil",
627
+ interactive=False,
628
+ )
629
+
630
+ status_display = gr.Markdown(
631
+ value="Awaiting user input...",
632
+ )
633
+
634
+ gr.Markdown(
635
+ """
636
+ ### 📖 About
637
+ This demo allows interactive inference of forest structure metrics from aerial RGBI imagery using a deep learning model trained via knowledge distillation from LiDAR data.
638
+
639
+ **Model**: SegFormer (MiT-B2) trained with knowledge distillation
640
+
641
+ **Patch Size**: 380m×380m ≈ 1900×1900 pixels at 0.2m GSD
642
+
643
+ **Test Tile ID**: 3348_5612_2
644
+ """
645
+ )
646
+
647
+ with gr.Column(scale=1):
648
+ gr.Markdown("### 🔍 Selected Patch")
649
+ patch_preview = gr.Image(
650
+ label="RGB Patch Preview", type="pil", interactive=False
651
+ )
652
+
653
+ gr.Markdown("---")
654
+ gr.Markdown("## 📊 Row 1: Predictions")
655
+
656
+ with gr.Row():
657
+ chm_pred_display = gr.Image(
658
+ label="CHM - Prediction", type="pil", interactive=False, height=600
659
+ )
660
+ pai_pred_display = gr.Image(
661
+ label="PAI - Prediction", type="pil", interactive=False, height=600
662
+ )
663
+ fhd_pred_display = gr.Image(
664
+ label="FHD - Prediction", type="pil", interactive=False, height=600
665
+ )
666
+
667
+ gr.Markdown("## 🎯 Row 2: Ground Truth")
668
+
669
+ with gr.Row():
670
+ chm_gt_display = gr.Image(
671
+ label="CHM - Ground Truth", type="pil", interactive=False, height=600
672
+ )
673
+ pai_gt_display = gr.Image(
674
+ label="PAI - Ground Truth", type="pil", interactive=False, height=600
675
+ )
676
+ fhd_gt_display = gr.Image(
677
+ label="FHD - Ground Truth", type="pil", interactive=False, height=600
678
+ )
679
+
680
+ gr.Markdown("## 📈 Row 3: Prediction vs Ground Truth Comparison")
681
+
682
+ with gr.Row():
683
+ chm_comparison = gr.Plot(label="CHM: Predicted vs Ground Truth")
684
+ pai_comparison = gr.Plot(label="PAI: Predicted vs Ground Truth")
685
+ fhd_comparison = gr.Plot(label="FHD: Predicted vs Ground Truth")
686
+
687
+ # Info section
688
+ gr.Markdown(
689
+ """
690
+ ---
691
+ ### 📚 Metric Definitions
692
+
693
+ - **CHM (Canopy Height Model)**: Height of forest canopy above ground (meters)
694
+ - **PAI (Plant Area Index)**: Total one-sided area of plant material per unit ground area
695
+ - **FHD (Foliage Height Diversity)**: Vertical distribution of vegetation layers
696
+
697
+ ### 🎯 How to Use
698
+ 1. Click anywhere on the input tile image
699
+ 2. The model runs inference on a 45m×45m patch at that location
700
+ 3. View predictions, ground truth, and comparison plots
701
+ 4. Click different locations to explore the tile
702
+ """
703
+ )
704
+
705
+ # Connect click event
706
+ input_tile_display.select(
707
+ fn=run_patch_inference,
708
+ inputs=[],
709
+ outputs=[
710
+ input_tile_display,
711
+ patch_preview,
712
+ chm_pred_display,
713
+ pai_pred_display,
714
+ fhd_pred_display,
715
+ chm_gt_display,
716
+ pai_gt_display,
717
+ fhd_gt_display,
718
+ chm_comparison,
719
+ pai_comparison,
720
+ fhd_comparison,
721
+ status_display,
722
+ ],
723
+ )
724
+
725
+ return demo
726
+
727
+
728
+ # ============================================================
729
+ # LAUNCH
730
+ # ============================================================
731
+ if __name__ == "__main__":
732
+
733
+ # Check files
734
+ print("\n" + "=" * 70)
735
+ print("Checking required files...")
736
+ print("=" * 70)
737
+
738
+ for path, name in [
739
+ (MODEL_PATH, "Model checkpoint"),
740
+ (FIXED_INPUT_TIF, "Input RGBI tile"),
741
+ (FIXED_GT_CHM, "Ground truth CHM"),
742
+ (FIXED_GT_PAI, "Ground truth PAI"),
743
+ (FIXED_GT_FHD, "Ground truth FHD"),
744
+ (VIDEO_PATH, "Overview video"),
745
+ ]:
746
+ if os.path.exists(path):
747
+ size_mb = os.path.getsize(path) / (1024**2)
748
+ print(f"✓ {name}: {path} ({size_mb:.1f} MB)")
749
+ else:
750
+ print(f"✗ {name}: {path} (NOT FOUND)")
751
+
752
+ print("=" * 70 + "\n")
753
+
754
+ # Launch
755
+ custom_theme = gr.themes.Soft(
756
+ primary_hue="emerald",
757
+ secondary_hue="teal",
758
+ )
759
+
760
+ custom_css = """
761
+ .gradio-container {
762
+ max-width: 1900px !important;
763
+ }
764
+ """
765
+
766
+ demo = create_demo()
767
+ demo.launch(
768
+ share=False,
769
+ server_name="0.0.0.0",
770
+ server_port=7860,
771
+ theme=custom_theme,
772
+ css=custom_css,
773
+ )
data/dop20rgbi_33348_5612_2_sn.tif ADDED

Git LFS Details

  • SHA256: a47bf4339bcb3bb90b703ea5181c96c8769aed275e330db379c3f009dc60640d
  • Pointer size: 134 Bytes
  • Size of remote file: 400 MB
data/lsc_33348_5612_2_sn_chm.tif ADDED

Git LFS Details

  • SHA256: 67858544eb4e7b14b733421c380127728559a881cb29d9562728e0cdb65c424e
  • Pointer size: 133 Bytes
  • Size of remote file: 32 MB
data/lsc_33348_5612_2_sn_fhd.tif ADDED

Git LFS Details

  • SHA256: caa4559e6554e25de057123d150d2eca8c638c8c3db56262c341fbefd326ee6e
  • Pointer size: 133 Bytes
  • Size of remote file: 32 MB
data/lsc_33348_5612_2_sn_pai.tif ADDED

Git LFS Details

  • SHA256: 8d68ebed8b83d4792e50c7d2c42106a3f9a06e9e11a8e5c29a478221fc00f2d2
  • Pointer size: 133 Bytes
  • Size of remote file: 32 MB
data/overview.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9d61b4092346e73bca76bc063a90767e8e0bdc3535783c0b167496655a73852
3
+ size 37614381
data/student_MiTB2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c507e2978fe053b8e77e974abc83a9e4e20638a0562e5f30ad59924342e1979a
3
+ size 97904267
requirements.txt ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.11.0
2
+ affine==2.4.0
3
+ aiohappyeyeballs==2.6.1
4
+ aiohttp==3.12.15
5
+ aiosignal==1.4.0
6
+ alabaster==0.7.16
7
+ ansicolors==1.1.8
8
+ appdirs==1.4.4
9
+ asn1crypto==1.5.1
10
+ asttokens==3.0.0
11
+ atomicwrites==1.4.1
12
+ attrs==23.2.0
13
+ azure-core==1.35.1
14
+ azure-datalake-store==1.0.1
15
+ azure-identity==1.25.0
16
+ azure-storage-blob==12.26.0
17
+ Babel==2.15.0
18
+ backports.entry-points-selectable==1.3.0
19
+ backports.functools-lru-cache==2.0.0
20
+ beniget==0.4.1
21
+ bitarray==2.9.2
22
+ bitstring==4.2.3
23
+ black==25.9.0
24
+ blist==1.3.6
25
+ boto3==1.40.40
26
+ botocore==1.40.40
27
+ Bottleneck==1.3.8
28
+ CacheControl==0.14.0
29
+ cachetools==5.5.2
30
+ cachy==0.3.0
31
+ certifi==2024.6.2
32
+ cffi==1.16.0
33
+ chardet==5.2.0
34
+ charset-normalizer==3.3.2
35
+ cleo==2.1.0
36
+ click==8.1.7
37
+ click-plugins==1.1.1
38
+ cligj==0.7.2
39
+ cloudpickle==3.0.0
40
+ colorama==0.4.6
41
+ comm==0.2.2
42
+ commonmark==0.9.1
43
+ contourpy==1.3.2
44
+ crashtest==0.4.1
45
+ cryptography==42.0.8
46
+ cycler==0.12.1
47
+ deap==1.4.1
48
+ debugpy==1.8.15
49
+ decorator==5.1.1
50
+ distlib==0.3.8
51
+ distro==1.9.0
52
+ docopt==0.6.2
53
+ docutils==0.21.2
54
+ doit==0.36.0
55
+ dulwich==0.22.1
56
+ ecdsa==0.19.0
57
+ editables==0.5
58
+ einops==0.8.1
59
+ entrypoints==0.4
60
+ exceptiongroup==1.2.1
61
+ execnet==2.1.1
62
+ executing==2.2.0
63
+ fastjsonschema==2.21.2
64
+ filelock==3.15.1
65
+ fonttools==4.59.0
66
+ frozenlist==1.7.0
67
+ fsspec==2024.6.0
68
+ future==1.0.0
69
+ gast==0.5.4
70
+ gcsfs==2025.9.0
71
+ GDAL==3.10.3
72
+ geopandas==1.1.1
73
+ glob2==0.7
74
+ google-api-core==2.25.1
75
+ google-auth==2.40.3
76
+ google-auth-oauthlib==1.2.2
77
+ google-cloud-core==2.4.3
78
+ google-cloud-storage==3.4.0
79
+ google-crc32c==1.7.1
80
+ google-resumable-media==2.7.2
81
+ googleapis-common-protos==1.70.0
82
+ hf-xet==1.1.5
83
+ html5lib==1.1
84
+ huggingface-hub==0.36.0
85
+ idna==3.7
86
+ imagesize==1.4.1
87
+ importlib-metadata==7.1.0
88
+ importlib-resources==6.4.0
89
+ iniconfig==2.0.0
90
+ intervaltree==3.1.0
91
+ intreehooks==1.0
92
+ ipaddress==1.0.23
93
+ ipykernel==6.30.0
94
+ ipython==9.4.0
95
+ ipython-pygments-lexers==1.1.1
96
+ isodate==0.7.2
97
+ jaraco.classes==3.4.0
98
+ jaraco.context==5.3.0
99
+ jedi==0.19.2
100
+ jeepney==0.8.0
101
+ Jinja2==3.1.4
102
+ jmespath==1.0.1
103
+ joblib==1.4.2
104
+ jsonschema==4.22.0
105
+ jsonschema-specifications==2023.12.1
106
+ jupyter-client==8.7.1
107
+ jupyter-core==5.7.3
108
+ kiwisolver==1.4.8
109
+ keyring==25.2.1
110
+ lark==1.1.9
111
+ liac-arff==2.5.0
112
+ lockfile==0.12.2
113
+ lxml==5.2.2
114
+ MarkupSafe==2.1.5
115
+ matplotlib==3.9.4
116
+ matplotlib-inline==0.1.7
117
+ mock==5.1.0
118
+ more-itertools==10.2.0
119
+ mpi4py==4.0.3
120
+ mpmath==1.4.0
121
+ msal==1.34.1
122
+ msal-extensions==1.2.2
123
+ msgpack==1.0.8
124
+ multidict==6.1.0
125
+ mypy-extensions==1.0.0
126
+ nest-asyncio==1.6.0
127
+ netCDF4==1.7.3
128
+ networkx==3.5.1
129
+ nose==1.3.7
130
+ numexpr==2.10.0
131
+ numpy==1.26.4
132
+ oauthlib==3.2.2
133
+ packaging==24.1
134
+ pandas==2.2.2
135
+ parso==0.8.4
136
+ pathspec==0.12.1
137
+ pbr==6.1.0
138
+ pexpect==4.9.0
139
+ pillow==10.4.0
140
+ pkgconfig==1.5.5
141
+ platformdirs==4.2.2
142
+ ply==3.11
143
+ pooch==1.8.2
144
+ prompt-toolkit==3.0.51
145
+ propcache==0.3.2
146
+ proto-plus==1.26.1
147
+ protobuf==6.32.1
148
+ psutil==5.9.8
149
+ ptyprocess==0.7.0
150
+ pure-eval==0.2.3
151
+ py==1.11.0
152
+ py-expression-eval==0.3.14
153
+ pyarrow==21.0.0
154
+ pyasn1==0.6.0
155
+ pyasn1-modules==0.4.2
156
+ pycparser==2.22
157
+ pycryptodome==3.20.0
158
+ pydevtool==0.3.0
159
+ pyforestscan==0.3.0
160
+ PyGithub==2.8.1
161
+ Pygments==2.18.0
162
+ PyJWT==2.10.1
163
+ pylev==1.4.0
164
+ PyNaCl==1.5.0
165
+ pyogrio==0.11.0
166
+ pyparsing==3.1.2
167
+ pyproj==3.7.1
168
+ pyrsistent==0.20.0
169
+ pytest==8.2.2
170
+ pytest-xdist==3.6.1
171
+ python-dateutil==2.9.0.post0
172
+ pythran==0.16.1
173
+ pytokens==0.1.10
174
+ pytoml==0.1.21
175
+ pytz==2024.1
176
+ PyYAML==6.0.2
177
+ pyzmq==27.0.0
178
+ rapidfuzz==3.9.3
179
+ rasterio==1.4.3
180
+ referencing==0.35.1
181
+ regex==2024.5.15
182
+ requests==2.32.3
183
+ requests-oauthlib==2.0.0
184
+ requests-toolbelt==1.0.0
185
+ rich==13.7.1
186
+ rich-click==1.8.3
187
+ rioxarray==0.19.0
188
+ rpds-py==0.18.1
189
+ rsa==4.9.1
190
+ s3transfer==0.14.0
191
+ safetensors==0.5.3
192
+ scandir==1.10.0
193
+ scipy==1.13.1
194
+ seaborn==0.13.2
195
+ SecretStorage==3.3.3
196
+ semantic-version==2.10.0
197
+ setuptools==80.9.0
198
+ shapely==2.1.1
199
+ shellingham==1.5.4
200
+ simplegeneric==0.8.1
201
+ simplejson==3.19.2
202
+ six==1.16.0
203
+ snowballstemmer==2.2.0
204
+ sortedcontainers==2.4.0
205
+ Sphinx==7.3.7
206
+ sphinx-bootstrap-theme==0.8.1
207
+ sphinxcontrib-applehelp==1.0.8
208
+ sphinxcontrib-devhelp==1.0.6
209
+ sphinxcontrib-htmlhelp==2.0.5
210
+ sphinxcontrib-jsmath==1.0.1
211
+ sphinxcontrib-qthelp==1.0.7
212
+ sphinxcontrib-serializinghtml==1.1.10
213
+ sphinxcontrib-websupport==1.2.7
214
+ stack-data==0.6.3
215
+ sympy==1.14.0
216
+ tabulate==0.9.0
217
+ tenacity==9.1.2
218
+ threadpoolctl==3.5.0
219
+ timm==1.0.17
220
+ tokenizers==0.22.1
221
+ toml==0.10.2
222
+ tomli==2.0.1
223
+ tomli-w==1.0.0
224
+ tomlkit==0.12.5
225
+ torch==2.7.1
226
+ torchvision==0.22.1
227
+ tornado==6.5.1
228
+ tqdm==4.67.1
229
+ traitlets==5.14.3
230
+ transformers==4.57.1
231
+ triton==3.3.1
232
+ typing-extensions==4.14.1
233
+ tzdata==2024.1
234
+ ujson==5.10.0
235
+ urllib3==2.2.1
236
+ versioneer==0.29
237
+ virtualenv==20.26.2
238
+ wcwidth==0.2.13
239
+ webencodings==0.5.1
240
+ xarray==2025.7.1
241
+ xlrd==2.0.1
242
+ yarl==1.20.1
243
+ zipfile36==0.1.3
244
+ zipp==3.19.2