Professional Noob committed on
Commit
d8b9abb
·
verified ·
1 Parent(s): e32d379

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +332 -14
app.py CHANGED
@@ -7,9 +7,17 @@ import numpy as np
7
  import spaces
8
  import torch
9
  import random
10
- from PIL import Image
11
  from typing import Iterable, Optional
12
 
 
 
 
 
 
 
 
 
13
  from huggingface_hub import hf_hub_download
14
  from safetensors.torch import load_file as safetensors_load_file
15
 
@@ -194,6 +202,221 @@ except Exception as e:
194
 
195
  MAX_SEED = np.iinfo(np.int32).max
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  # ============================================================
198
  # LoRA adapters + presets
199
  # ============================================================
@@ -777,6 +1000,55 @@ def on_lora_change_ui(selected_lora, current_prompt, current_extras_condition_on
777
  extras_update = gr.update(value=current_extras_condition_only)
778
 
779
  return prompt_update, img2_update, extras_update
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780
 
781
 
782
  # ============================================================
@@ -869,12 +1141,9 @@ def infer(
869
  try:
870
  print(
871
  "[DEBUG][infer] submitting request | "
872
- f"lora_adapter={lora_adapter!r} seed={seed} prompt={prompt!r} "
873
- f"canvas={width}x{height} target_area={target_area} "
874
- f"extras_condition_only={extras_condition_only} vae_image_indices={vae_image_indices} "
875
- f"pad_to_canvas={bool(pad_to_canvas)}"
876
  )
877
-
878
  result = pipe(
879
  image=pipe_images,
880
  prompt=prompt,
@@ -887,7 +1156,7 @@ def infer(
887
  vae_image_indices=vae_image_indices,
888
  pad_to_canvas=bool(pad_to_canvas),
889
  ).images[0]
890
- return result, seed
891
  finally:
892
  gc.collect()
893
  if torch.cuda.is_available():
@@ -897,13 +1166,14 @@ def infer(
897
  @spaces.GPU
898
  def infer_example(input_image, prompt, lora_adapter):
899
  if input_image is None:
900
- return None, 0
901
  input_pil = input_image.convert("RGB")
902
  guidance_scale = 1.0
903
  steps = 4
904
  # Examples don't supply Image 2 or extra images; and example list doesn't include AnyPose/BFS.
905
- result, seed = infer(input_pil, None, None, prompt, lora_adapter, 0, True, guidance_scale, steps, 1.0, True, True)
906
- return result, seed
 
907
 
908
 
909
  # ============================================================
@@ -958,6 +1228,21 @@ with gr.Blocks() as demo:
958
  with gr.Column():
959
  output_image = gr.Image(label="Output Image", interactive=False, format="png", height=353)
960
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
961
  with gr.Row():
962
  lora_choices = [NONE_LORA] + list(ADAPTER_SPECS.keys())
963
  lora_adapter = gr.Dropdown(
@@ -967,6 +1252,27 @@ with gr.Blocks() as demo:
967
  )
968
 
969
  with gr.Accordion("Advanced Settings", open=False, visible=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
970
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
971
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
972
  guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
@@ -987,7 +1293,7 @@ with gr.Blocks() as demo:
987
  value=True,
988
  )
989
 
990
- # On LoRA selection: preset prompt + toggle Image 2 + default extras routing
991
  lora_adapter.change(
992
  fn=on_lora_change_ui,
993
  inputs=[lora_adapter, prompt, extras_condition_only],
@@ -1022,7 +1328,7 @@ with gr.Blocks() as demo:
1022
  ["examples/11.jpg", "Upscale this picture to 4K resolution.", "Upscale2K"],
1023
  ],
1024
  inputs=[input_image_1, prompt, lora_adapter],
1025
- outputs=[output_image, seed],
1026
  fn=infer_example,
1027
  cache_examples=False,
1028
  label="Examples",
@@ -1044,9 +1350,21 @@ with gr.Blocks() as demo:
1044
  extras_condition_only,
1045
  pad_to_canvas,
1046
  ],
1047
- outputs=[output_image, seed],
1048
  )
1049
 
 
 
 
 
 
 
 
 
 
 
 
 
1050
  if __name__ == "__main__":
1051
  demo.queue(max_size=30).launch(
1052
  css=css,
@@ -1054,4 +1372,4 @@ if __name__ == "__main__":
1054
  mcp_server=True,
1055
  ssr_mode=False,
1056
  show_error=True,
1057
- )
 
7
  import spaces
8
  import torch
9
  import random
10
+ from PIL import Image, ImageDraw
11
  from typing import Iterable, Optional
12
 
13
+ from transformers import (
14
+ AutoProcessor,
15
+ RTDetrForObjectDetection,
16
+ VitPoseForPoseEstimation,
17
+ AutoImageProcessor,
18
+ AutoModelForDepthEstimation,
19
+ )
20
+
21
  from huggingface_hub import hf_hub_download
22
  from safetensors.torch import load_file as safetensors_load_file
23
 
 
202
 
203
  MAX_SEED = np.iinfo(np.int32).max
204
 
205
# ============================================================
# Derived conditioning (Transformers): Pose + Depth
# ============================================================
# Pose: ViTPose (top-down). The official docs show an RT-DETR -> ViTPose flow:
#   https://huggingface.co/docs/transformers/model_doc/vitpose
# Depth: Depth Anything V2 Small (Transformers-compatible checkpoint):
#   https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf

# Hub checkpoint ids used by the loaders below.
POSE_MODEL_ID = "usyd-community/vitpose-base-simple"
POSE_DETECTOR_ID = "PekingU/rtdetr_r50vd_coco_o365"
DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"

# Lazily filled caches, keyed by the device string ("cpu" / "cuda").
_POSE_CACHE = {}
_DEPTH_CACHE = {}

# Pairs of COCO-17 keypoint indices joined by a line when drawing the
# skeleton (an approximate "OpenPose-like" stick figure).
COCO17_EDGES = [
    # head
    (0, 1),
    (0, 2),
    (1, 3),
    (2, 4),
    # shoulders
    (5, 6),
    # left arm
    (5, 7),
    (7, 9),
    # right arm
    (6, 8),
    (8, 10),
    # torso
    (5, 11),
    (6, 12),
    (11, 12),
    # left leg
    (11, 13),
    (13, 15),
    # right leg
    (12, 14),
    (14, 16),
]
231
+
232
+ def _derived_device(use_gpu: bool) -> torch.device:
233
+ return torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")
234
+
235
+
236
def _load_pose_models(dev: torch.device):
    """Return ``(det_proc, det_model, pose_proc, pose_model)`` for ``dev``.

    The tuple is built on the first request for a device and stored in
    ``_POSE_CACHE`` under ``str(dev)``; later calls for the same device
    return the stored tuple without loading anything again.
    """
    cache_key = str(dev)
    try:
        return _POSE_CACHE[cache_key]
    except KeyError:
        pass  # first request for this device: load below

    # Detector (optional but used for multi-person boxes)
    detector_processor = AutoProcessor.from_pretrained(POSE_DETECTOR_ID)
    detector = RTDetrForObjectDetection.from_pretrained(POSE_DETECTOR_ID).to(dev)

    # Pose model
    keypoint_processor = AutoProcessor.from_pretrained(POSE_MODEL_ID)
    estimator = VitPoseForPoseEstimation.from_pretrained(POSE_MODEL_ID).to(dev)

    # Both models are put in eval mode before being cached.
    for module in (detector, estimator):
        module.eval()

    bundle = (detector_processor, detector, keypoint_processor, estimator)
    _POSE_CACHE[cache_key] = bundle
    return bundle
254
+
255
+
256
def _load_depth_models(dev: torch.device):
    """Return ``(processor, model)`` for depth estimation on ``dev``.

    The pair is loaded from ``DEPTH_MODEL_ID`` on the first request for a
    device and stored in ``_DEPTH_CACHE`` under ``str(dev)``; later calls for
    the same device return the stored pair.
    """
    cache_key = str(dev)
    try:
        return _DEPTH_CACHE[cache_key]
    except KeyError:
        pass  # first request for this device: load below

    image_processor = AutoImageProcessor.from_pretrained(DEPTH_MODEL_ID)
    depth_model = AutoModelForDepthEstimation.from_pretrained(DEPTH_MODEL_ID).to(dev)
    depth_model.eval()

    pair = (image_processor, depth_model)
    _DEPTH_CACHE[cache_key] = pair
    return pair
267
+
268
+
269
def _draw_skeleton_on_blank(
    size: tuple[int, int],
    persons_keypoints: list[np.ndarray],
    persons_scores: list[np.ndarray],
    kp_thresh: float = 0.20,
    point_r: int = 3,
    line_w: int = 3,
) -> Image.Image:
    """Draw white stick figures on a black RGB canvas.

    Args:
        size: Canvas ``(width, height)``.
        persons_keypoints: One array per person; row ``i`` supplies the
            ``x`` (column 0) and ``y`` (column 1) of keypoint ``i``.
        persons_scores: One array per person holding a score per keypoint.
        kp_thresh: A keypoint scoring below this is not drawn, and neither is
            any edge that touches it.
        point_r: Half-size of the ellipse drawn for each keypoint.
        line_w: Width of each edge line.

    Returns:
        The rendered canvas.
    """
    white = (255, 255, 255)
    w, h = size
    canvas = Image.new("RGB", (w, h), (0, 0, 0))
    draw = ImageDraw.Draw(canvas)

    for kps, sc in zip(persons_keypoints, persons_scores):
        # Only indices present in BOTH arrays are usable. The edge loop used
        # to check len(sc) alone, so a keypoint array shorter than its score
        # array raised IndexError there while the point loop tolerated it.
        usable = min(len(sc), len(kps))

        # Draw edges
        for a, b in COCO17_EDGES:
            if a >= usable or b >= usable:
                continue
            if sc[a] < kp_thresh or sc[b] < kp_thresh:
                continue
            start = (float(kps[a, 0]), float(kps[a, 1]))
            end = (float(kps[b, 0]), float(kps[b, 1]))
            draw.line([start, end], fill=white, width=line_w)

        # Draw keypoints (zip stops at the shorter of the two arrays)
        for point, score in zip(kps, sc):
            if score < kp_thresh:
                continue
            x, y = float(point[0]), float(point[1])
            draw.ellipse(
                [(x - point_r, y - point_r), (x + point_r, y + point_r)],
                fill=white,
                outline=None,
            )

    return canvas
304
+
305
+
306
def _full_frame_box(width: int, height: int) -> np.ndarray:
    """Return a single COCO-format ``[x, y, w, h]`` box covering the whole image."""
    return np.array([[0.0, 0.0, float(width), float(height)]], dtype=np.float32)


def _detect_person_boxes(img, det_proc, det_model, dev, det_thresh: float) -> np.ndarray:
    """Run the detector on ``img`` and return person boxes as COCO ``[x, y, w, h]`` rows (possibly none)."""
    w, h = img.size
    inputs = det_proc(images=img, return_tensors="pt").to(dev)
    with torch.no_grad():
        outputs = det_model(**inputs)

    results = det_proc.post_process_object_detection(
        outputs,
        target_sizes=torch.tensor([(h, w)], device=dev),
        threshold=det_thresh,
    )[0]

    # COCO label 0 is "person" for COCO-trained detectors
    xyxy = results["boxes"][results["labels"] == 0].detach().cpu().numpy()

    # Convert VOC x1,y1,x2,y2 to COCO x,y,w,h (astype returns a fresh array)
    xywh = xyxy.astype(np.float32)
    xywh[:, 2] -= xywh[:, 0]
    xywh[:, 3] -= xywh[:, 1]
    return xywh


def make_pose_map(
    img: Image.Image,
    *,
    use_gpu: bool,
    mode: str,
    det_thresh: float = 0.30,
    max_people: int = 4,
) -> Image.Image:
    """Return an OpenPose-like skeleton map (RGB) using Transformers models.

    mode:
      - "fast": full-frame box (no detector). Good when Image 1 is already a single subject.
      - "detect": RT-DETR person boxes -> ViTPose. Better for multi-person scenes.
        Any value other than "fast" takes this path.

    Args:
        img: Source image; converted to RGB before use.
        use_gpu: Request CUDA for the models (CPU is used when unavailable).
        mode: See above.
        det_thresh: Score threshold passed to the detector post-processing.
        max_people: Upper bound on the number of boxes sent to the pose
            model. Values below 1 are treated as 1 so that at least one box
            is always passed on.

    Returns:
        A black canvas the size of ``img`` with the skeleton(s) drawn in
        white, or a plain black canvas when the pose model returns nothing.
    """
    img = img.convert("RGB")
    dev = _derived_device(use_gpu)
    det_proc, det_model, pose_proc, pose_model = _load_pose_models(dev)

    w, h = img.size

    boxes = None
    if mode != "fast":
        detected = _detect_person_boxes(img, det_proc, det_model, dev, det_thresh)
        if detected.size:
            boxes = detected
    if boxes is None:
        # "fast" mode, or the detector found nobody: use the whole frame.
        boxes = _full_frame_box(w, h)

    # Previously a max_people of 0 produced an empty box list and a negative
    # value sliced boxes off the end; clamp so the cap is always >= 1.
    boxes = boxes[: max(1, int(max_people))]

    pose_inputs = pose_proc(img, boxes=[boxes], return_tensors="pt").to(dev)
    with torch.no_grad():
        pose_outputs = pose_model(**pose_inputs)

    pose_results = pose_proc.post_process_pose_estimation(pose_outputs, boxes=[boxes])[0]

    persons_kps = [pr["keypoints"].detach().cpu().numpy() for pr in pose_results]
    persons_sc = [pr["scores"].detach().cpu().numpy() for pr in pose_results]

    if not persons_kps:
        # No pose found; return black canvas
        return Image.new("RGB", img.size, (0, 0, 0))

    return _draw_skeleton_on_blank(img.size, persons_kps, persons_sc)
375
+
376
+
377
def make_depth_map(img: Image.Image, *, use_gpu: bool) -> Image.Image:
    """Return a grayscale (RGB) depth map using Depth Anything V2 Small.

    The prediction is resized to the input image's size, min-max normalised,
    scaled to 8-bit and returned as an RGB image.

    Args:
        img: Source image; converted to RGB before use.
        use_gpu: Request CUDA for the model (CPU is used when unavailable).
    """
    rgb = img.convert("RGB")
    device = _derived_device(use_gpu)
    processor, depth_model = _load_depth_models(device)

    batch = processor(images=rgb, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        model_out = depth_model(**batch)

    # predicted_depth: (B, H, W)
    predicted = model_out.predicted_depth

    # Upsample to original image size
    resized = torch.nn.functional.interpolate(
        predicted.unsqueeze(1),
        size=(rgb.height, rgb.width),
        mode="bicubic",
        align_corners=False,
    )
    first = resized.squeeze(1)[0]

    depth = first.detach().float().cpu().numpy()
    shifted = depth - float(depth.min())
    # The epsilon keeps the division defined when the map is constant.
    scaled = shifted / (float(shifted.max()) + 1e-8)

    pixels = (scaled * 255.0).clip(0, 255).astype(np.uint8)
    return Image.fromarray(pixels, mode="L").convert("RGB")
408
+
409
+
410
def _append_to_gallery(existing, new_img: Image.Image):
    """Return a new list: the usable items of ``existing`` followed by ``new_img``.

    Each existing item is passed through ``_to_pil_rgb``; items for which it
    returns ``None`` are dropped. A falsy ``existing`` contributes nothing.
    """
    converted = (_to_pil_rgb(entry) for entry in (existing or ()))
    return [pil for pil in converted if pil is not None] + [new_img]
419
+
420
  # ============================================================
421
  # LoRA adapters + presets
422
  # ============================================================
 
1000
  extras_update = gr.update(value=current_extras_condition_only)
1001
 
1002
  return prompt_update, img2_update, extras_update
1003
+ # ============================================================
1004
+ # UI helpers: output routing + derived conditioning
1005
+ # ============================================================
1006
+
1007
def set_output_as_image1(last):
    """Build the update that puts the stored output into the Image 1 input.

    Raises:
        gr.Error: when ``last`` is ``None`` (nothing has been stored yet).
    """
    if last is not None:
        return gr.update(value=last)
    raise gr.Error("No output available yet.")
1011
+
1012
+
1013
def set_output_as_image2(last):
    """Build the update that puts the stored output into the Image 2 input.

    Raises:
        gr.Error: when ``last`` is ``None`` (nothing has been stored yet).
    """
    if last is not None:
        return gr.update(value=last)
    raise gr.Error("No output available yet.")
1017
+
1018
+
1019
def set_output_as_extra(last, existing_extra):
    """Return the extra-reference gallery with the stored output appended.

    Raises:
        gr.Error: when ``last`` is ``None`` (nothing has been stored yet).
    """
    if last is not None:
        return _append_to_gallery(existing_extra, last)
    raise gr.Error("No output available yet.")
1023
+
1024
+
1025
@spaces.GPU
def add_derived_ref(img1, existing_extra, derived_type, derived_use_gpu, derived_max_people):
    """Derive a pose or depth map from Image 1 and append it to the extras gallery.

    Returns a pair of updates: the new gallery value, and the preview image
    (made visible with the derived map). When ``derived_type`` is ``"None"``
    the gallery is left as is and the preview is hidden and cleared.

    Raises:
        gr.Error: when ``img1`` is missing or ``derived_type`` is not one of
            the known options.
    """
    if img1 is None:
        raise gr.Error("Please upload Image 1 first.")

    if derived_type == "None":
        keep_gallery = gr.update(value=existing_extra)
        hide_preview = gr.update(visible=False, value=None)
        return keep_gallery, hide_preview

    source = img1.convert("RGB")

    if derived_type == "Depth (Depth Anything V2 Small)":
        derived = make_depth_map(source, use_gpu=bool(derived_use_gpu))
    elif derived_type == "Pose (ViTPose + RT-DETR detect)":
        derived = make_pose_map(
            source,
            use_gpu=bool(derived_use_gpu),
            mode="detect",
            max_people=int(derived_max_people),
        )
    elif derived_type == "Pose (ViTPose, fast)":
        derived = make_pose_map(source, use_gpu=bool(derived_use_gpu), mode="fast")
    else:
        raise gr.Error(f"Unknown derived type: {derived_type}")

    gallery_update = gr.update(value=_append_to_gallery(existing_extra, derived))
    preview_update = gr.update(visible=True, value=derived)
    return gallery_update, preview_update
1051
+
1052
 
1053
 
1054
  # ============================================================
 
1141
  try:
1142
  print(
1143
  "[DEBUG][infer] submitting request | "
1144
+ f"lora_adapter={lora_adapter!r} seed={seed} prompt={prompt!r}"
 
 
 
1145
  )
1146
+
1147
  result = pipe(
1148
  image=pipe_images,
1149
  prompt=prompt,
 
1156
  vae_image_indices=vae_image_indices,
1157
  pad_to_canvas=bool(pad_to_canvas),
1158
  ).images[0]
1159
+ return result, seed, result
1160
  finally:
1161
  gc.collect()
1162
  if torch.cuda.is_available():
 
1166
@spaces.GPU
def infer_example(input_image, prompt, lora_adapter):
    """Run one example row through ``infer`` with fixed settings.

    Args:
        input_image: Example image used as Image 1, or ``None``.
        prompt: Example prompt text.
        lora_adapter: Example adapter name.

    Returns:
        The three values ``infer`` returns (image, seed, image again for the
        stored last output), or ``(None, 0, None)`` when no image was given.
    """
    if input_image is None:
        return None, 0, None
    input_pil = input_image.convert("RGB")
    guidance_scale = 1.0
    steps = 4
    # Examples don't supply Image 2 or extra images; and example list doesn't include AnyPose/BFS.
    # The trailing 1.0, True, True are the values this call passed before the
    # output-routing change; infer's signature was not touched by that change,
    # so they are passed explicitly again rather than relying on defaults.
    # NOTE(review): presumably target area, extras_condition_only and
    # pad_to_canvas, in that order — confirm against infer's signature.
    result, seed, last = infer(
        input_pil, None, None, prompt, lora_adapter, 0, True, guidance_scale, steps, 1.0, True, True
    )
    # A stray ", result" line used to follow this return and made the module
    # fail to parse.
    return result, seed, last
1177
 
1178
 
1179
  # ============================================================
 
1228
  with gr.Column():
1229
  output_image = gr.Image(label="Output Image", interactive=False, format="png", height=353)
1230
 
1231
+ last_output = gr.State(value=None)
1232
+
1233
+ with gr.Row():
1234
+ btn_out_to_img1 = gr.Button("⬅️ Output → Image 1", variant="secondary")
1235
+ btn_out_to_img2 = gr.Button("⬅️ Output → Image 2", variant="secondary")
1236
+ btn_out_to_extra = gr.Button("➕ Output → Extra Ref", variant="secondary")
1237
+
1238
+ derived_preview = gr.Image(
1239
+ label="Derived Conditioning Preview",
1240
+ interactive=False,
1241
+ format="png",
1242
+ height=200,
1243
+ visible=False,
1244
+ )
1245
+
1246
  with gr.Row():
1247
  lora_choices = [NONE_LORA] + list(ADAPTER_SPECS.keys())
1248
  lora_adapter = gr.Dropdown(
 
1252
  )
1253
 
1254
  with gr.Accordion("Advanced Settings", open=False, visible=True):
1255
+ with gr.Accordion("Derived Conditioning (Pose / Depth)", open=False):
1256
+ derived_type = gr.Dropdown(
1257
+ label="Derived Type (from Image 1)",
1258
+ choices=[
1259
+ "None",
1260
+ "Pose (ViTPose, fast)",
1261
+ "Pose (ViTPose + RT-DETR detect)",
1262
+ "Depth (Depth Anything V2 Small)",
1263
+ ],
1264
+ value="None",
1265
+ )
1266
+ derived_use_gpu = gr.Checkbox(label="Use GPU for derived model", value=False)
1267
+ derived_max_people = gr.Slider(
1268
+ label="Max people (pose detect mode)",
1269
+ minimum=1,
1270
+ maximum=10,
1271
+ step=1,
1272
+ value=4,
1273
+ )
1274
+ add_derived_btn = gr.Button("➕ Add derived ref to Extras (conditioning-only recommended)")
1275
+
1276
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
1277
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
1278
  guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
 
1293
  value=True,
1294
  )
1295
 
1296
+ # On LoRA selection: preset prompt + toggle Image 2 + default extras routing
1297
  lora_adapter.change(
1298
  fn=on_lora_change_ui,
1299
  inputs=[lora_adapter, prompt, extras_condition_only],
 
1328
  ["examples/11.jpg", "Upscale this picture to 4K resolution.", "Upscale2K"],
1329
  ],
1330
  inputs=[input_image_1, prompt, lora_adapter],
1331
+ outputs=[output_image, seed, last_output],
1332
  fn=infer_example,
1333
  cache_examples=False,
1334
  label="Examples",
 
1350
  extras_condition_only,
1351
  pad_to_canvas,
1352
  ],
1353
+ outputs=[output_image, seed, last_output],
1354
  )
1355
 
1356
+ # Output routing buttons
1357
+ btn_out_to_img1.click(fn=set_output_as_image1, inputs=[last_output], outputs=[input_image_1])
1358
+ btn_out_to_img2.click(fn=set_output_as_image2, inputs=[last_output], outputs=[input_image_2])
1359
+ btn_out_to_extra.click(fn=set_output_as_extra, inputs=[last_output, input_images_extra], outputs=[input_images_extra])
1360
+
1361
+ # Derived conditioning: append pose/depth map as extra ref (UI shows preview)
1362
+ add_derived_btn.click(
1363
+ fn=add_derived_ref,
1364
+ inputs=[input_image_1, input_images_extra, derived_type, derived_use_gpu, derived_max_people],
1365
+ outputs=[input_images_extra, derived_preview],
1366
+ )
1367
+
1368
  if __name__ == "__main__":
1369
  demo.queue(max_size=30).launch(
1370
  css=css,
 
1372
  mcp_server=True,
1373
  ssr_mode=False,
1374
  show_error=True,
1375
+ )