Professional Noob commited on
Commit
5d0d264
·
verified ·
1 Parent(s): b43b225

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -26
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import gc
 
3
  import gradio as gr
4
  import numpy as np
5
  import spaces
@@ -8,7 +9,7 @@ import random
8
  from PIL import Image
9
  from typing import Iterable, Optional
10
 
11
- from huggingface_hub import hf_hub_download
12
  from safetensors.torch import load_file as safetensors_load_file
13
 
14
  from gradio.themes import Soft
@@ -119,17 +120,119 @@ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
119
 
120
  dtype = torch.bfloat16
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  pipe = QwenImageEditPlusPipeline.from_pretrained(
123
  "Qwen/Qwen-Image-Edit-2511",
124
  transformer=QwenImageTransformer2DModel.from_pretrained(
125
- "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO",
126
- subfolder="transformer",
127
  torch_dtype=dtype,
128
  device_map="cuda",
129
  ),
130
  torch_dtype=dtype,
131
  ).to(device)
132
 
 
 
133
  # Apply FA3 Optimization
134
  try:
135
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
@@ -285,11 +388,9 @@ LOADED_ADAPTERS = set()
285
  # Helpers: resolution
286
  # ============================================================
287
 
288
-
289
  def _round8(x: int) -> int:
290
  return max(8, (int(x) // 8) * 8)
291
 
292
-
293
  def compute_dimensions(image: Image.Image, long_edge: int) -> tuple[int, int]:
294
  w, h = image.size
295
  if w >= h:
@@ -300,25 +401,20 @@ def compute_dimensions(image: Image.Image, long_edge: int) -> tuple[int, int]:
300
  new_w = int(round(long_edge * (w / h)))
301
  return _round8(new_w), _round8(new_h)
302
 
303
-
304
  def get_target_long_edge_for_lora(lora_adapter: str) -> int:
305
  spec = ADAPTER_SPECS.get(lora_adapter, {})
306
  return int(spec.get("target_long_edge", 1024))
307
 
308
-
309
  # ============================================================
310
  # Helpers: multi-input routing + gallery normalization
311
  # ============================================================
312
 
313
-
314
  def lora_requires_two_images(lora_adapter: str) -> bool:
315
  return bool(ADAPTER_SPECS.get(lora_adapter, {}).get("requires_two_images", False))
316
 
317
-
318
  def image2_label_for_lora(lora_adapter: str) -> str:
319
  return str(ADAPTER_SPECS.get(lora_adapter, {}).get("image2_label", "Upload Reference (Image 2)"))
320
 
321
-
322
  def _to_pil_rgb(x) -> Optional[Image.Image]:
323
  """
324
  Accepts PIL / numpy / (image, caption) tuples from gr.Gallery and returns PIL RGB.
@@ -345,7 +441,6 @@ def _to_pil_rgb(x) -> Optional[Image.Image]:
345
  except Exception:
346
  return None
347
 
348
-
349
  def build_labeled_images(
350
  img1: Image.Image,
351
  img2: Optional[Image.Image],
@@ -377,12 +472,10 @@ def build_labeled_images(
377
 
378
  return labeled
379
 
380
-
381
  # ============================================================
382
  # Helpers: BFS alpha key fix
383
  # ============================================================
384
 
385
-
386
  def _inject_missing_alpha_keys(state_dict: dict) -> dict:
387
  """
388
  Diffusers' Qwen LoRA converter expects '<module>.alpha' keys.
@@ -418,7 +511,6 @@ def _inject_missing_alpha_keys(state_dict: dict) -> dict:
418
 
419
  return state_dict
420
 
421
-
422
  def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name: str, needs_alpha_fix: bool = False):
423
  """
424
  Normal path: pipe.load_lora_weights(repo, weight_name=..., adapter_name=...)
@@ -439,12 +531,10 @@ def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name:
439
  pipe.load_lora_weights(sd, adapter_name=adapter_name)
440
  return
441
 
442
-
443
  # ============================================================
444
  # LoRA loader: single/package + strengths
445
  # ============================================================
446
 
447
-
448
  def _ensure_loaded_and_get_active_adapters(selected_lora: str):
449
  spec = ADAPTER_SPECS.get(selected_lora)
450
  if not spec:
@@ -510,12 +600,10 @@ def _ensure_loaded_and_get_active_adapters(selected_lora: str):
510
 
511
  return adapter_names, adapter_weights
512
 
513
-
514
  # ============================================================
515
  # UI handlers
516
  # ============================================================
517
 
518
-
519
  def on_lora_change_ui(selected_lora, current_prompt):
520
  # Preset prompt (fill only if empty)
521
  if selected_lora != NONE_LORA:
@@ -535,19 +623,18 @@ def on_lora_change_ui(selected_lora, current_prompt):
535
 
536
  return prompt_update, img2_update
537
 
538
-
539
  # ============================================================
540
  # Inference
541
  # ============================================================
542
 
543
-
544
  @spaces.GPU
545
  def infer(
546
  input_image_1,
547
  input_image_2,
548
- input_images_extra, # NEW: gallery multi-image box
549
  prompt,
550
  lora_adapter,
 
551
  seed,
552
  randomize_seed,
553
  guidance_scale,
@@ -558,6 +645,10 @@ def infer(
558
  if torch.cuda.is_available():
559
  torch.cuda.empty_cache()
560
 
 
 
 
 
561
  if input_image_1 is None:
562
  raise gr.Error("Please upload Image 1.")
563
 
@@ -625,7 +716,6 @@ def infer(
625
  if torch.cuda.is_available():
626
  torch.cuda.empty_cache()
627
 
628
-
629
  @spaces.GPU
630
  def infer_example(input_image, prompt, lora_adapter):
631
  if input_image is None:
@@ -634,10 +724,20 @@ def infer_example(input_image, prompt, lora_adapter):
634
  guidance_scale = 1.0
635
  steps = 4
636
  # Examples don't supply Image 2 or extra images; and example list doesn't include AnyPose/BFS.
637
- result, seed = infer(input_pil, None, None, prompt, lora_adapter, 0, True, guidance_scale, steps)
 
 
 
 
 
 
 
 
 
 
 
638
  return result, seed
639
 
640
-
641
  # ============================================================
642
  # UI
643
  # ============================================================
@@ -664,7 +764,6 @@ with gr.Blocks() as demo:
664
  input_image_1 = gr.Image(label="Upload Image 1 (Base / Target)", type="pil", height=290)
665
  input_image_2 = gr.Image(label="Upload Reference (Image 2)", type="pil", height=290, visible=False)
666
 
667
- # NEW: multi-image input box (supports multiple images)
668
  input_images_extra = gr.Gallery(
669
  label="Upload Additional Images (auto-indexed after Image 1/2)",
670
  type="pil",
@@ -685,6 +784,22 @@ with gr.Blocks() as demo:
685
  with gr.Column():
686
  output_image = gr.Image(label="Output Image", interactive=False, format="png", height=353)
687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
688
  with gr.Row():
689
  lora_choices = [NONE_LORA] + list(ADAPTER_SPECS.keys())
690
  lora_adapter = gr.Dropdown(
@@ -706,6 +821,20 @@ with gr.Blocks() as demo:
706
  outputs=[prompt, input_image_2],
707
  )
708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709
  gr.Examples(
710
  examples=[
711
  ["examples/1.jpg", "Transform into anime.", "Photo-to-Anime"],
@@ -746,9 +875,10 @@ with gr.Blocks() as demo:
746
  inputs=[
747
  input_image_1,
748
  input_image_2,
749
- input_images_extra, # NEW
750
  prompt,
751
  lora_adapter,
 
752
  seed,
753
  randomize_seed,
754
  guidance_scale,
 
1
  import os
2
  import gc
3
+ import re
4
  import gradio as gr
5
  import numpy as np
6
  import spaces
 
9
  from PIL import Image
10
  from typing import Iterable, Optional
11
 
12
+ from huggingface_hub import hf_hub_download, HfApi
13
  from safetensors.torch import load_file as safetensors_load_file
14
 
15
  from gradio.themes import Soft
 
120
 
121
  dtype = torch.bfloat16
122
 
123
+ # ------------------------------------------------------------
124
+ # AIO versioning
125
+ # ------------------------------------------------------------
126
+ AIO_REPO_ID = "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO"
127
+ DEFAULT_AIO_VERSION = "v19"
128
+ _VERSION_RE = re.compile(r"^v\d+$")
129
+
130
+ def discover_aio_versions(repo_id: str) -> list[str]:
131
+ """
132
+ Discovers versions that follow vXX/transformer/ in the HF repo.
133
+ Returns sorted list like: ['v19', 'v21', ...]
134
+ """
135
+ api = HfApi()
136
+ try:
137
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model")
138
+ except Exception as e:
139
+ print(f"⚠️ Could not list repo files for {repo_id}: {e}")
140
+ return [DEFAULT_AIO_VERSION]
141
+
142
+ versions = set()
143
+ for p in files:
144
+ if "/transformer/" not in p:
145
+ continue
146
+ head = p.split("/transformer/", 1)[0] # "v19"
147
+ if _VERSION_RE.fullmatch(head):
148
+ versions.add(head)
149
+
150
+ if not versions:
151
+ versions = {DEFAULT_AIO_VERSION}
152
+
153
+ return sorted(versions, key=lambda x: int(x[1:]))
154
+
155
+ AVAILABLE_AIO_VERSIONS = discover_aio_versions(AIO_REPO_ID)
156
+
157
+ # Track currently loaded transformer version
158
+ CURRENT_AIO_VERSION: Optional[str] = None
159
+
160
+ def _free_cuda():
161
+ gc.collect()
162
+ if torch.cuda.is_available():
163
+ torch.cuda.empty_cache()
164
+
165
+ @spaces.GPU
166
+ def switch_aio_version(version: str):
167
+ """
168
+ Loads transformer weights from {version}/transformer/ into the already-created pipeline.
169
+ """
170
+ global CURRENT_AIO_VERSION, pipe
171
+
172
+ if version is None or str(version).strip() == "":
173
+ version = DEFAULT_AIO_VERSION
174
+
175
+ if CURRENT_AIO_VERSION == version:
176
+ return gr.update(value=f"✅ Already using {version}")
177
+
178
+ _free_cuda()
179
+
180
+ subfolder = f"{version}/transformer"
181
+ print(f"🔁 Switching AIO transformer to: {AIO_REPO_ID} / {subfolder}")
182
+
183
+ old_transformer = getattr(pipe, "transformer", None)
184
+
185
+ new_transformer = QwenImageTransformer2DModel.from_pretrained(
186
+ AIO_REPO_ID,
187
+ subfolder=subfolder,
188
+ torch_dtype=dtype,
189
+ device_map="cuda",
190
+ )
191
+ pipe.transformer = new_transformer
192
+
193
+ # Re-apply FA3 Optimization
194
+ try:
195
+ pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
196
+ print("Flash Attention 3 Processor set successfully.")
197
+ except Exception as e:
198
+ print(f"Warning: Could not set FA3 processor: {e}")
199
+
200
+ # Best-effort free old transformer reference
201
+ try:
202
+ del old_transformer
203
+ except Exception:
204
+ pass
205
+
206
+ _free_cuda()
207
+
208
+ CURRENT_AIO_VERSION = version
209
+ return gr.update(value=f"✅ Loaded {version} ({subfolder}/)")
210
+
211
+ def refresh_aio_versions():
212
+ global AVAILABLE_AIO_VERSIONS
213
+ AVAILABLE_AIO_VERSIONS = discover_aio_versions(AIO_REPO_ID)
214
+ new_default = DEFAULT_AIO_VERSION if DEFAULT_AIO_VERSION in AVAILABLE_AIO_VERSIONS else AVAILABLE_AIO_VERSIONS[0]
215
+ return (
216
+ gr.update(choices=AVAILABLE_AIO_VERSIONS, value=new_default),
217
+ gr.update(value=f"🔄 Found: {', '.join(AVAILABLE_AIO_VERSIONS)}")
218
+ )
219
+
220
+ # ------------------------------------------------------------
221
+ # Create pipeline (loads DEFAULT_AIO_VERSION only)
222
+ # ------------------------------------------------------------
223
  pipe = QwenImageEditPlusPipeline.from_pretrained(
224
  "Qwen/Qwen-Image-Edit-2511",
225
  transformer=QwenImageTransformer2DModel.from_pretrained(
226
+ AIO_REPO_ID,
227
+ subfolder=f"{DEFAULT_AIO_VERSION}/transformer",
228
  torch_dtype=dtype,
229
  device_map="cuda",
230
  ),
231
  torch_dtype=dtype,
232
  ).to(device)
233
 
234
+ CURRENT_AIO_VERSION = DEFAULT_AIO_VERSION
235
+
236
  # Apply FA3 Optimization
237
  try:
238
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
 
388
  # Helpers: resolution
389
  # ============================================================
390
 
 
391
  def _round8(x: int) -> int:
392
  return max(8, (int(x) // 8) * 8)
393
 
 
394
  def compute_dimensions(image: Image.Image, long_edge: int) -> tuple[int, int]:
395
  w, h = image.size
396
  if w >= h:
 
401
  new_w = int(round(long_edge * (w / h)))
402
  return _round8(new_w), _round8(new_h)
403
 
 
404
  def get_target_long_edge_for_lora(lora_adapter: str) -> int:
405
  spec = ADAPTER_SPECS.get(lora_adapter, {})
406
  return int(spec.get("target_long_edge", 1024))
407
 
 
408
  # ============================================================
409
  # Helpers: multi-input routing + gallery normalization
410
  # ============================================================
411
 
 
412
  def lora_requires_two_images(lora_adapter: str) -> bool:
413
  return bool(ADAPTER_SPECS.get(lora_adapter, {}).get("requires_two_images", False))
414
 
 
415
  def image2_label_for_lora(lora_adapter: str) -> str:
416
  return str(ADAPTER_SPECS.get(lora_adapter, {}).get("image2_label", "Upload Reference (Image 2)"))
417
 
 
418
  def _to_pil_rgb(x) -> Optional[Image.Image]:
419
  """
420
  Accepts PIL / numpy / (image, caption) tuples from gr.Gallery and returns PIL RGB.
 
441
  except Exception:
442
  return None
443
 
 
444
  def build_labeled_images(
445
  img1: Image.Image,
446
  img2: Optional[Image.Image],
 
472
 
473
  return labeled
474
 
 
475
  # ============================================================
476
  # Helpers: BFS alpha key fix
477
  # ============================================================
478
 
 
479
  def _inject_missing_alpha_keys(state_dict: dict) -> dict:
480
  """
481
  Diffusers' Qwen LoRA converter expects '<module>.alpha' keys.
 
511
 
512
  return state_dict
513
 
 
514
  def _load_lora_weights_with_fallback(repo: str, weight_name: str, adapter_name: str, needs_alpha_fix: bool = False):
515
  """
516
  Normal path: pipe.load_lora_weights(repo, weight_name=..., adapter_name=...)
 
531
  pipe.load_lora_weights(sd, adapter_name=adapter_name)
532
  return
533
 
 
534
  # ============================================================
535
  # LoRA loader: single/package + strengths
536
  # ============================================================
537
 
 
538
  def _ensure_loaded_and_get_active_adapters(selected_lora: str):
539
  spec = ADAPTER_SPECS.get(selected_lora)
540
  if not spec:
 
600
 
601
  return adapter_names, adapter_weights
602
 
 
603
  # ============================================================
604
  # UI handlers
605
  # ============================================================
606
 
 
607
  def on_lora_change_ui(selected_lora, current_prompt):
608
  # Preset prompt (fill only if empty)
609
  if selected_lora != NONE_LORA:
 
623
 
624
  return prompt_update, img2_update
625
 
 
626
  # ============================================================
627
  # Inference
628
  # ============================================================
629
 
 
630
  @spaces.GPU
631
  def infer(
632
  input_image_1,
633
  input_image_2,
634
+ input_images_extra, # gallery multi-image box
635
  prompt,
636
  lora_adapter,
637
+ aio_version, # NEW: selected AIO version
638
  seed,
639
  randomize_seed,
640
  guidance_scale,
 
645
  if torch.cuda.is_available():
646
  torch.cuda.empty_cache()
647
 
648
+ # Ensure the requested transformer version is loaded
649
+ if aio_version and aio_version != CURRENT_AIO_VERSION:
650
+ switch_aio_version(aio_version)
651
+
652
  if input_image_1 is None:
653
  raise gr.Error("Please upload Image 1.")
654
 
 
716
  if torch.cuda.is_available():
717
  torch.cuda.empty_cache()
718
 
 
719
  @spaces.GPU
720
  def infer_example(input_image, prompt, lora_adapter):
721
  if input_image is None:
 
724
  guidance_scale = 1.0
725
  steps = 4
726
  # Examples don't supply Image 2 or extra images; and example list doesn't include AnyPose/BFS.
727
+ result, seed = infer(
728
+ input_pil,
729
+ None,
730
+ None,
731
+ prompt,
732
+ lora_adapter,
733
+ CURRENT_AIO_VERSION or DEFAULT_AIO_VERSION, # NEW: keep whatever is loaded
734
+ 0,
735
+ True,
736
+ guidance_scale,
737
+ steps,
738
+ )
739
  return result, seed
740
 
 
741
  # ============================================================
742
  # UI
743
  # ============================================================
 
764
  input_image_1 = gr.Image(label="Upload Image 1 (Base / Target)", type="pil", height=290)
765
  input_image_2 = gr.Image(label="Upload Reference (Image 2)", type="pil", height=290, visible=False)
766
 
 
767
  input_images_extra = gr.Gallery(
768
  label="Upload Additional Images (auto-indexed after Image 1/2)",
769
  type="pil",
 
784
  with gr.Column():
785
  output_image = gr.Image(label="Output Image", interactive=False, format="png", height=353)
786
 
787
+ # NEW: AIO version selector + refresh
788
+ with gr.Row():
789
+ aio_version = gr.Dropdown(
790
+ label="Phr00t Rapid AIO Version",
791
+ choices=AVAILABLE_AIO_VERSIONS,
792
+ value=DEFAULT_AIO_VERSION if DEFAULT_AIO_VERSION in AVAILABLE_AIO_VERSIONS else AVAILABLE_AIO_VERSIONS[0],
793
+ interactive=True,
794
+ )
795
+ refresh_versions_btn = gr.Button("↻", scale=0)
796
+
797
+ aio_status = gr.Textbox(
798
+ label="Model Status",
799
+ value=f"Using {CURRENT_AIO_VERSION}",
800
+ interactive=False,
801
+ )
802
+
803
  with gr.Row():
804
  lora_choices = [NONE_LORA] + list(ADAPTER_SPECS.keys())
805
  lora_adapter = gr.Dropdown(
 
821
  outputs=[prompt, input_image_2],
822
  )
823
 
824
+ # On AIO version change: swap transformer
825
+ aio_version.change(
826
+ fn=switch_aio_version,
827
+ inputs=[aio_version],
828
+ outputs=[aio_status],
829
+ )
830
+
831
+ # Refresh available versions
832
+ refresh_versions_btn.click(
833
+ fn=refresh_aio_versions,
834
+ inputs=[],
835
+ outputs=[aio_version, aio_status],
836
+ )
837
+
838
  gr.Examples(
839
  examples=[
840
  ["examples/1.jpg", "Transform into anime.", "Photo-to-Anime"],
 
875
  inputs=[
876
  input_image_1,
877
  input_image_2,
878
+ input_images_extra,
879
  prompt,
880
  lora_adapter,
881
+ aio_version, # NEW
882
  seed,
883
  randomize_seed,
884
  guidance_scale,