Professional Noob committed on
Commit
bfef88e
·
verified ·
1 Parent(s): d5b8c31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -139
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import os
2
  import re
3
  import gc
4
- import threading
5
  import random
 
 
6
  from typing import Iterable, Optional
7
 
8
  import gradio as gr
@@ -112,7 +113,7 @@ if torch.cuda.is_available():
112
  print("Using device:", device)
113
 
114
  # ============================================================
115
- # Pipeline + AIO versioning
116
  # ============================================================
117
 
118
  from diffusers import FlowMatchEulerDiscreteScheduler # noqa: F401
@@ -122,60 +123,97 @@ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
122
 
123
  dtype = torch.bfloat16
124
 
 
 
 
 
125
  AIO_REPO_ID = "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO"
126
  DEFAULT_AIO_VERSION = "v19"
127
- _AIO_VERSION_RE = re.compile(r"^(v\d+)/transformer/")
128
 
129
- AIO_SWITCH_LOCK = threading.Lock()
 
 
130
  CURRENT_AIO_VERSION = DEFAULT_AIO_VERSION
131
 
132
 
133
- def _discover_aio_versions(repo_id: str) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  """
135
- Lists versions by scanning repo files for paths like:
136
- v19/transformer/...
137
- v21/transformer/...
138
- Returns sorted: v1, v2, v10, ...
139
  """
140
  try:
141
  api = HfApi()
142
- files = api.list_repo_files(repo_id=repo_id)
143
  versions = set()
144
  for f in files:
145
- m = _AIO_VERSION_RE.match(f)
146
- if m:
147
- versions.add(m.group(1))
 
 
 
148
  if not versions:
149
  return [DEFAULT_AIO_VERSION]
150
- # numeric sort on the digits after 'v'
151
- out = sorted(list(versions), key=lambda s: int(s[1:]) if s[1:].isdigit() else 10**9)
152
- return out
153
  except Exception as e:
154
- print(f"⚠️ Could not discover AIO versions from repo: {e}")
155
  return [DEFAULT_AIO_VERSION]
156
 
157
 
158
- AIO_VERSIONS = _discover_aio_versions(AIO_REPO_ID)
159
- if DEFAULT_AIO_VERSION not in AIO_VERSIONS and AIO_VERSIONS:
160
- CURRENT_AIO_VERSION = AIO_VERSIONS[0]
161
- else:
162
- CURRENT_AIO_VERSION = DEFAULT_AIO_VERSION
163
 
164
 
165
- def _load_aio_transformer(version: str) -> QwenImageTransformer2DModel:
166
  """
167
- IMPORTANT: we do NOT pass device_map="cuda" here.
168
- Loading with device_map can trigger diffusers' CUDA caching-allocator warmup path,
169
- which is where your NVML/PyTorch allocator assert is happening on MIG/ZeroGPU.
 
170
  """
171
- subfolder = f"{version}/transformer"
172
- print(f"📦 Loading AIO transformer: {AIO_REPO_ID} / {subfolder} (CPU -> then move to {device})")
173
- t = QwenImageTransformer2DModel.from_pretrained(
174
- AIO_REPO_ID,
175
- subfolder=subfolder,
176
- torch_dtype=dtype,
177
- )
178
- return t
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
 
181
  def _apply_fa3_if_possible():
@@ -186,33 +224,18 @@ def _apply_fa3_if_possible():
186
  print(f"Warning: Could not set FA3 processor: {e}")
187
 
188
 
189
- def _hard_cuda_cleanup():
190
- gc.collect()
191
- if torch.cuda.is_available():
192
- try:
193
- torch.cuda.synchronize()
194
- except Exception:
195
- pass
196
- torch.cuda.empty_cache()
197
- try:
198
- torch.cuda.ipc_collect()
199
- except Exception:
200
- pass
201
-
202
-
203
- # Build pipeline once. We only swap pipe.transformer at runtime.
204
  pipe = QwenImageEditPlusPipeline.from_pretrained(
205
  "Qwen/Qwen-Image-Edit-2511",
206
- transformer=_load_aio_transformer(CURRENT_AIO_VERSION),
 
 
 
 
 
207
  torch_dtype=dtype,
208
  ).to(device)
209
 
210
- # move transformer to device (pipeline .to() might not fully move nested module in some custom pipelines)
211
- try:
212
- pipe.transformer.to(device)
213
- except Exception:
214
- pass
215
-
216
  _apply_fa3_if_possible()
217
 
218
  MAX_SEED = np.iinfo(np.int32).max
@@ -566,12 +589,11 @@ def _ensure_loaded_and_get_active_adapters(selected_lora: str):
566
 
567
 
568
  # ============================================================
569
- # AIO version switching (robust for MIG/ZeroGPU)
570
  # ============================================================
571
 
572
 
573
  def _unload_all_loras():
574
- # When swapping transformer versions, previously loaded LoRAs are no longer safe to keep around.
575
  global LOADED_ADAPTERS
576
  try:
577
  pipe.set_adapters([], adapter_weights=[])
@@ -584,39 +606,39 @@ def _unload_all_loras():
584
  LOADED_ADAPTERS.clear()
585
 
586
 
587
- def switch_aio_version(target_version: str):
588
  """
589
- Swap only the transformer module inside the pipeline.
590
-
591
- Key safeguards vs your crash:
592
- - detach & move old transformer off GPU before loading a new one
593
- - load new transformer without device_map (CPU), then move to GPU
594
- - clear LoRAs because module graph changes across transformers
595
  """
596
  global CURRENT_AIO_VERSION
597
 
598
- if not target_version:
599
- return
600
- if target_version == CURRENT_AIO_VERSION:
601
  return
602
 
603
- with AIO_SWITCH_LOCK:
604
- if target_version == CURRENT_AIO_VERSION:
605
  return
606
 
607
- print(f"🔁 Switching AIO transformer to: {AIO_REPO_ID} / {target_version}/transformer")
608
 
609
- # Make sure no adapters are active and free adapter memory references
610
  _unload_all_loras()
611
 
612
- # Detach old transformer as aggressively as possible
613
  old_t = getattr(pipe, "transformer", None)
 
 
 
 
 
 
 
614
  try:
615
- # Register a tiny placeholder so the pipeline drops references to the huge module
616
- pipe.register_modules(transformer=torch.nn.Identity())
617
  except Exception:
618
- # Fallback: direct attribute overwrite
619
- pipe.transformer = torch.nn.Identity()
620
 
621
  if old_t is not None:
622
  try:
@@ -627,25 +649,22 @@ def switch_aio_version(target_version: str):
627
 
628
  _hard_cuda_cleanup()
629
 
630
- # Load on CPU (no device_map) then move to GPU
631
- new_t = _load_aio_transformer(target_version)
632
- try:
633
- new_t.to(device)
634
- except Exception as e:
635
- # Ensure we don't leave partially loaded modules around
636
- del new_t
637
- _hard_cuda_cleanup()
638
- raise gr.Error(f"Failed to move transformer {target_version} to {device}: {e}")
639
-
640
- # Swap in
641
  try:
642
- pipe.register_modules(transformer=new_t)
643
  except Exception:
644
  pipe.transformer = new_t
645
 
646
  _apply_fa3_if_possible()
647
 
648
- CURRENT_AIO_VERSION = target_version
649
  _hard_cuda_cleanup()
650
 
651
 
@@ -655,6 +674,7 @@ def switch_aio_version(target_version: str):
655
 
656
 
657
  def on_lora_change_ui(selected_lora, current_prompt):
 
658
  if selected_lora != NONE_LORA:
659
  preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
660
  if preset and (current_prompt is None or str(current_prompt).strip() == ""):
@@ -664,6 +684,7 @@ def on_lora_change_ui(selected_lora, current_prompt):
664
  else:
665
  prompt_update = gr.update(value=current_prompt)
666
 
 
667
  if lora_requires_two_images(selected_lora):
668
  img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
669
  else:
@@ -673,14 +694,19 @@ def on_lora_change_ui(selected_lora, current_prompt):
673
 
674
 
675
  def refresh_aio_versions_ui(current_value: str):
676
- global AIO_VERSIONS
677
- AIO_VERSIONS = _discover_aio_versions(AIO_REPO_ID)
678
- if current_value in AIO_VERSIONS:
 
 
 
 
679
  new_value = current_value
680
  else:
681
- new_value = DEFAULT_AIO_VERSION if DEFAULT_AIO_VERSION in AIO_VERSIONS else (AIO_VERSIONS[0] if AIO_VERSIONS else DEFAULT_AIO_VERSION)
682
 
683
- return gr.update(choices=AIO_VERSIONS, value=new_value), f"Found {len(AIO_VERSIONS)} version(s): {', '.join(AIO_VERSIONS)}"
 
684
 
685
 
686
  # ============================================================
@@ -702,58 +728,62 @@ def infer(
702
  steps,
703
  progress=gr.Progress(track_tqdm=True),
704
  ):
705
- _hard_cuda_cleanup()
 
706
 
707
- if input_image_1 is None:
708
- raise gr.Error("Please upload Image 1.")
709
 
710
- # Ensure selected AIO version is loaded
711
- if aio_version and aio_version != CURRENT_AIO_VERSION:
712
- switch_aio_version(aio_version)
713
 
714
- # Handle "None"
715
- if lora_adapter == NONE_LORA:
716
- try:
717
- pipe.set_adapters([], adapter_weights=[])
718
- except Exception:
719
- if LOADED_ADAPTERS:
720
- pipe.set_adapters(list(LOADED_ADAPTERS), adapter_weights=[0.0] * len(LOADED_ADAPTERS))
721
- else:
722
- adapter_names, adapter_weights = _ensure_loaded_and_get_active_adapters(lora_adapter)
723
- pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
724
 
725
- if randomize_seed:
726
- seed = random.randint(0, MAX_SEED)
727
 
728
- generator = torch.Generator(device=device).manual_seed(seed)
729
- negative_prompt = (
730
- "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, "
731
- "extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
732
- )
733
 
734
- img1 = input_image_1.convert("RGB")
735
- img2 = input_image_2.convert("RGB") if input_image_2 is not None else None
736
 
737
- extra_imgs: list[Image.Image] = []
738
- if input_images_extra:
739
- for item in input_images_extra:
740
- pil = _to_pil_rgb(item)
741
- if pil is not None:
742
- extra_imgs.append(pil)
 
743
 
744
- if lora_requires_two_images(lora_adapter) and img2 is None:
745
- raise gr.Error("This LoRA needs two images. Please upload Image 2 as well.")
 
746
 
747
- labeled = build_labeled_images(img1, img2, extra_imgs)
 
748
 
749
- pipe_images = list(labeled.values())
750
- if len(pipe_images) == 1:
751
- pipe_images = pipe_images[0]
752
 
753
- target_long_edge = get_target_long_edge_for_lora(lora_adapter)
754
- width, height = compute_dimensions(img1, target_long_edge)
 
755
 
756
- try:
757
  result = pipe(
758
  image=pipe_images,
759
  prompt=prompt,
@@ -764,7 +794,11 @@ def infer(
764
  generator=generator,
765
  true_cfg_scale=guidance_scale,
766
  ).images[0]
 
767
  return result, seed
 
 
 
768
  finally:
769
  _hard_cuda_cleanup()
770
 
@@ -776,7 +810,7 @@ def infer_example(input_image, prompt, lora_adapter):
776
  input_pil = input_image.convert("RGB")
777
  guidance_scale = 1.0
778
  steps = 4
779
- # Examples always run on current loaded AIO version
780
  result, seed = infer(CURRENT_AIO_VERSION, input_pil, None, None, prompt, lora_adapter, 0, True, guidance_scale, steps)
781
  return result, seed
782
 
@@ -805,11 +839,24 @@ with gr.Blocks() as demo:
805
  with gr.Row():
806
  aio_version = gr.Dropdown(
807
  label="Phr00t Rapid AIO Version",
808
- choices=AIO_VERSIONS,
809
- value=CURRENT_AIO_VERSION if CURRENT_AIO_VERSION in AIO_VERSIONS else (AIO_VERSIONS[0] if AIO_VERSIONS else DEFAULT_AIO_VERSION),
 
810
  )
811
  refresh_versions = gr.Button("Refresh", variant="secondary")
812
- aio_status = gr.Markdown(f"Found {len(AIO_VERSIONS)} version(s): {', '.join(AIO_VERSIONS)}")
 
 
 
 
 
 
 
 
 
 
 
 
813
 
814
  refresh_versions.click(
815
  fn=refresh_aio_versions_ui,
@@ -897,7 +944,14 @@ with gr.Blocks() as demo:
897
  label="Examples",
898
  )
899
 
 
 
 
900
  run_button.click(
 
 
 
 
901
  fn=infer,
902
  inputs=[
903
  aio_version,
 
1
  import os
2
  import re
3
  import gc
 
4
  import random
5
+ import threading
6
+ import traceback
7
  from typing import Iterable, Optional
8
 
9
  import gradio as gr
 
113
  print("Using device:", device)
114
 
115
  # ============================================================
116
+ # Pipeline
117
  # ============================================================
118
 
119
  from diffusers import FlowMatchEulerDiscreteScheduler # noqa: F401
 
123
 
124
  dtype = torch.bfloat16
125
 
126
+ # ============================================================
127
+ # AIO version discovery + caching (CPU) + switching (GPU)
128
+ # ============================================================
129
+
130
  AIO_REPO_ID = "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO"
131
  DEFAULT_AIO_VERSION = "v19"
 
132
 
133
+ _AIO_VER_RE = re.compile(r"^(v\d+)$")
134
+ _AIO_SWITCH_LOCK = threading.Lock()
135
+
136
  CURRENT_AIO_VERSION = DEFAULT_AIO_VERSION
137
 
138
 
139
+ def _hard_cuda_cleanup():
140
+ gc.collect()
141
+ if torch.cuda.is_available():
142
+ try:
143
+ torch.cuda.synchronize()
144
+ except Exception:
145
+ pass
146
+ torch.cuda.empty_cache()
147
+ try:
148
+ torch.cuda.ipc_collect()
149
+ except Exception:
150
+ pass
151
+
152
+
153
def discover_aio_versions(repo_id: str) -> list[str]:
    """
    List the AIO versions published in ``repo_id``.

    A version is the path prefix in front of ``/transformer/`` when that
    prefix is exactly ``v`` followed by digits, i.e. file paths shaped like:
      vNN/transformer/...

    Returns the version names sorted by their numeric part. Falls back to
    ``[DEFAULT_AIO_VERSION]`` when no path matches or when listing the repo
    raises (the error is printed, not re-raised).
    """
    marker = "/transformer/"
    try:
        repo_files = HfApi().list_repo_files(repo_id=repo_id, repo_type="model")

        # Take everything before the first "/transformer/" and keep it only
        # if the whole prefix is a version tag.
        prefixes = (path.split(marker, 1)[0] for path in repo_files if marker in path)
        found = {prefix for prefix in prefixes if _AIO_VER_RE.fullmatch(prefix)}

        if found:
            # Numeric sort so that v2 comes before v10.
            return sorted(found, key=lambda name: int(name[1:]))
        return [DEFAULT_AIO_VERSION]
    except Exception as err:
        print(f"⚠️ AIO version discovery failed: {err}")
        return [DEFAULT_AIO_VERSION]
176
 
177
 
178
+ AVAILABLE_AIO_VERSIONS = discover_aio_versions(AIO_REPO_ID)
179
+ if DEFAULT_AIO_VERSION not in AVAILABLE_AIO_VERSIONS and AVAILABLE_AIO_VERSIONS:
180
+ DEFAULT_AIO_VERSION = AVAILABLE_AIO_VERSIONS[0]
181
+ CURRENT_AIO_VERSION = DEFAULT_AIO_VERSION
 
182
 
183
 
184
def ensure_aio_cached(version: str) -> None:
    """
    Fetch every file under ``<version>/transformer/`` of ``AIO_REPO_ID`` with
    ``hf_hub_download``.

    An empty/None ``version`` falls back to ``DEFAULT_AIO_VERSION``.

    Intended to run as a CPU-side step so that no Hub download has to happen
    later inside a GPU task (long downloads there are reported to cause
    "GPU task aborted" on ZeroGPU).

    Raises:
        gr.Error: if the repo has no files under ``<version>/transformer/``.
    """
    folder = f"{version or DEFAULT_AIO_VERSION}/transformer"
    prefix = folder + "/"

    repo_listing = HfApi().list_repo_files(repo_id=AIO_REPO_ID, repo_type="model")
    wanted = [path for path in repo_listing if path.startswith(prefix)]
    if not wanted:
        raise gr.Error(f"No files found under {folder}/ in {AIO_REPO_ID}")

    # One download call per file in the version's transformer folder.
    for path in wanted:
        hf_hub_download(repo_id=AIO_REPO_ID, filename=path, repo_type="model")
202
+
203
+
204
def ensure_aio_cached_ui(version: str):
    """
    Gradio handler: cache the selected AIO version via ``ensure_aio_cached``.

    An empty/None ``version`` falls back to ``DEFAULT_AIO_VERSION``.

    Returns:
        A pair of ``gr.update`` objects: the status text
        ("✅ Cached <version> (ready)") and ``interactive=True`` for the
        second output component.

    Raises:
        gr.Error: if caching fails. The full traceback is printed first and
        the original exception is attached as the explicit cause.
    """
    version = version or DEFAULT_AIO_VERSION
    try:
        print(f"⬇️ Caching AIO version on CPU: {version}")
        ensure_aio_cached(version)
    except Exception as e:
        print("❌ Cache step failed:\n", traceback.format_exc())
        # Chain the original error so logs show it as the direct cause.
        raise gr.Error(f"Cache failed for {version}: {e}") from e
    else:
        return gr.update(value=f"✅ Cached {version} (ready)"), gr.update(interactive=True)
217
 
218
 
219
  def _apply_fa3_if_possible():
 
224
  print(f"Warning: Could not set FA3 processor: {e}")
225
 
226
 
227
+ # Build pipeline once (default version at startup)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  pipe = QwenImageEditPlusPipeline.from_pretrained(
229
  "Qwen/Qwen-Image-Edit-2511",
230
+ transformer=QwenImageTransformer2DModel.from_pretrained(
231
+ AIO_REPO_ID,
232
+ subfolder=f"{DEFAULT_AIO_VERSION}/transformer",
233
+ torch_dtype=dtype,
234
+ device_map="cuda", # keep your existing setup
235
+ ),
236
  torch_dtype=dtype,
237
  ).to(device)
238
 
 
 
 
 
 
 
239
  _apply_fa3_if_possible()
240
 
241
  MAX_SEED = np.iinfo(np.int32).max
 
589
 
590
 
591
  # ============================================================
592
+ # AIO switch (GPU, local cache only; called inside infer)
593
  # ============================================================
594
 
595
 
596
  def _unload_all_loras():
 
597
  global LOADED_ADAPTERS
598
  try:
599
  pipe.set_adapters([], adapter_weights=[])
 
606
  LOADED_ADAPTERS.clear()
607
 
608
 
609
+ def _switch_aio_version_local_only(version: str):
610
  """
611
+ Must be called while already inside a GPU task.
612
+ Uses local_files_only=True (assumes ensure_aio_cached ran on CPU first).
 
 
 
 
613
  """
614
  global CURRENT_AIO_VERSION
615
 
616
+ version = version or DEFAULT_AIO_VERSION
617
+ if version == CURRENT_AIO_VERSION:
 
618
  return
619
 
620
+ with _AIO_SWITCH_LOCK:
621
+ if version == CURRENT_AIO_VERSION:
622
  return
623
 
624
+ print(f"🔁 Switching AIO transformer to: {AIO_REPO_ID} / {version}/transformer (local-only)")
625
 
626
+ # LoRAs are transformer-graph dependent
627
  _unload_all_loras()
628
 
629
+ # Detach old transformer strongly
630
  old_t = getattr(pipe, "transformer", None)
631
+
632
+ try:
633
+ if hasattr(pipe, "_modules") and "transformer" in pipe._modules:
634
+ pipe._modules.pop("transformer", None)
635
+ except Exception:
636
+ pass
637
+
638
  try:
639
+ pipe.transformer = None
 
640
  except Exception:
641
+ pass
 
642
 
643
  if old_t is not None:
644
  try:
 
649
 
650
  _hard_cuda_cleanup()
651
 
652
+ # Load from local cache only (no downloads inside GPU task)
653
+ new_t = QwenImageTransformer2DModel.from_pretrained(
654
+ AIO_REPO_ID,
655
+ subfolder=f"{version}/transformer",
656
+ torch_dtype=dtype,
657
+ local_files_only=True,
658
+ ).to(device)
659
+
 
 
 
660
  try:
661
+ pipe.add_module("transformer", new_t)
662
  except Exception:
663
  pipe.transformer = new_t
664
 
665
  _apply_fa3_if_possible()
666
 
667
+ CURRENT_AIO_VERSION = version
668
  _hard_cuda_cleanup()
669
 
670
 
 
674
 
675
 
676
  def on_lora_change_ui(selected_lora, current_prompt):
677
+ # Preset prompt (fill only if empty)
678
  if selected_lora != NONE_LORA:
679
  preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
680
  if preset and (current_prompt is None or str(current_prompt).strip() == ""):
 
684
  else:
685
  prompt_update = gr.update(value=current_prompt)
686
 
687
+ # Image2 visibility/label
688
  if lora_requires_two_images(selected_lora):
689
  img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
690
  else:
 
694
 
695
 
696
def refresh_aio_versions_ui(current_value: str):
    """
    Gradio handler: re-run version discovery and refresh the dropdown.

    Updates the module globals ``AVAILABLE_AIO_VERSIONS`` and, when the
    current default is not among the discovered versions,
    ``DEFAULT_AIO_VERSION`` (set to the first discovered version).

    Returns:
        A pair of ``gr.update`` objects: new dropdown choices/value (the
        current selection is kept if still available, otherwise the
        default), and the "Found N version(s): ..." status text.
    """
    global AVAILABLE_AIO_VERSIONS, DEFAULT_AIO_VERSION

    versions = discover_aio_versions(AIO_REPO_ID)
    AVAILABLE_AIO_VERSIONS = versions

    # Re-point the default only if something was found and the default is gone.
    if versions and DEFAULT_AIO_VERSION not in versions:
        DEFAULT_AIO_VERSION = versions[0]

    selected = current_value if current_value in versions else DEFAULT_AIO_VERSION
    summary = f"Found {len(versions)} version(s): {', '.join(versions)}"
    return gr.update(choices=versions, value=selected), gr.update(value=summary)
710
 
711
 
712
  # ============================================================
 
728
  steps,
729
  progress=gr.Progress(track_tqdm=True),
730
  ):
731
+ try:
732
+ _hard_cuda_cleanup()
733
 
734
+ if input_image_1 is None:
735
+ raise gr.Error("Please upload Image 1.")
736
 
737
+ # Switch AIO version quickly (local cache only). No downloads here.
738
+ if aio_version and aio_version != CURRENT_AIO_VERSION:
739
+ _switch_aio_version_local_only(aio_version)
740
 
741
+ # Handle "None"
742
+ if lora_adapter == NONE_LORA:
743
+ try:
744
+ pipe.set_adapters([], adapter_weights=[])
745
+ except Exception:
746
+ if LOADED_ADAPTERS:
747
+ pipe.set_adapters(list(LOADED_ADAPTERS), adapter_weights=[0.0] * len(LOADED_ADAPTERS))
748
+ else:
749
+ adapter_names, adapter_weights = _ensure_loaded_and_get_active_adapters(lora_adapter)
750
+ pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
751
 
752
+ if randomize_seed:
753
+ seed = random.randint(0, MAX_SEED)
754
 
755
+ generator = torch.Generator(device=device).manual_seed(seed)
756
+ negative_prompt = (
757
+ "worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, "
758
+ "extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
759
+ )
760
 
761
+ img1 = input_image_1.convert("RGB")
762
+ img2 = input_image_2.convert("RGB") if input_image_2 is not None else None
763
 
764
+ # Normalize extra images (Gallery) to PIL RGB
765
+ extra_imgs: list[Image.Image] = []
766
+ if input_images_extra:
767
+ for item in input_images_extra:
768
+ pil = _to_pil_rgb(item)
769
+ if pil is not None:
770
+ extra_imgs.append(pil)
771
 
772
+ # Enforce existing 2-image LoRA behavior
773
+ if lora_requires_two_images(lora_adapter) and img2 is None:
774
+ raise gr.Error("This LoRA needs two images. Please upload Image 2 as well.")
775
 
776
+ # Label images as image_1, image_2, image_3...
777
+ labeled = build_labeled_images(img1, img2, extra_imgs)
778
 
779
+ pipe_images = list(labeled.values())
780
+ if len(pipe_images) == 1:
781
+ pipe_images = pipe_images[0]
782
 
783
+ # Resolution derived from Image 1
784
+ target_long_edge = get_target_long_edge_for_lora(lora_adapter)
785
+ width, height = compute_dimensions(img1, target_long_edge)
786
 
 
787
  result = pipe(
788
  image=pipe_images,
789
  prompt=prompt,
 
794
  generator=generator,
795
  true_cfg_scale=guidance_scale,
796
  ).images[0]
797
+
798
  return result, seed
799
+ except Exception as e:
800
+ print("❌ Infer failed:\n", traceback.format_exc())
801
+ raise
802
  finally:
803
  _hard_cuda_cleanup()
804
 
 
810
  input_pil = input_image.convert("RGB")
811
  guidance_scale = 1.0
812
  steps = 4
813
+ # Examples use current loaded AIO version, no swapping.
814
  result, seed = infer(CURRENT_AIO_VERSION, input_pil, None, None, prompt, lora_adapter, 0, True, guidance_scale, steps)
815
  return result, seed
816
 
 
839
  with gr.Row():
840
  aio_version = gr.Dropdown(
841
  label="Phr00t Rapid AIO Version",
842
+ choices=AVAILABLE_AIO_VERSIONS,
843
+ value=DEFAULT_AIO_VERSION,
844
+ interactive=True,
845
  )
846
  refresh_versions = gr.Button("Refresh", variant="secondary")
847
+
848
+ aio_status = gr.Markdown(
849
+ f"Found {len(AVAILABLE_AIO_VERSIONS)} version(s): {', '.join(AVAILABLE_AIO_VERSIONS)}"
850
+ )
851
+
852
+ # When user changes version: CPU-cache it (fast if already cached)
853
+ # Also keep Run enabled (interactive=True)
854
+ run_button_placeholder = gr.Button("Edit Image", variant="primary", visible=False)
855
+ aio_version.change(
856
+ fn=ensure_aio_cached_ui,
857
+ inputs=[aio_version],
858
+ outputs=[aio_status, run_button_placeholder],
859
+ )
860
 
861
  refresh_versions.click(
862
  fn=refresh_aio_versions_ui,
 
944
  label="Examples",
945
  )
946
 
947
+ # Run pipeline:
948
+ # 1) CPU cache selected version (fast if already cached)
949
+ # 2) GPU infer (will switch using local_files_only=True if needed)
950
  run_button.click(
951
+ fn=ensure_aio_cached_ui,
952
+ inputs=[aio_version],
953
+ outputs=[aio_status, run_button],
954
+ ).then(
955
  fn=infer,
956
  inputs=[
957
  aio_version,