Professional Noob commited on
Commit
63f2ef7
·
verified ·
1 Parent(s): d92200a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -104
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import re
3
  import gc
 
 
4
  import random
5
  import threading
6
  import traceback
@@ -114,7 +116,7 @@ if torch.cuda.is_available():
114
  print("Using device:", device)
115
 
116
  # ============================================================
117
- # Pipeline imports (KEEP your existing transformer class)
118
  # ============================================================
119
 
120
  from diffusers import FlowMatchEulerDiscreteScheduler # noqa: F401
@@ -125,41 +127,34 @@ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
125
  dtype = torch.bfloat16
126
 
127
  # ============================================================
128
- # AIO versioning
129
  # ============================================================
130
 
131
  AIO_REPO_ID = "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO"
132
  DEFAULT_AIO_VERSION = "v19"
133
  _AIO_VER_RE = re.compile(r"^(v\d+)$")
134
 
135
- # Cache control (prevents double-download when dropdown+run are both triggered)
136
- _CACHED_AIO_VERSIONS: set[str] = set()
137
- _CACHE_LOCKS: dict[str, threading.Lock] = {}
138
- _CACHE_LOCKS_GUARD = threading.Lock()
139
-
140
- # GPU switch lock (prevents concurrent swaps)
141
- _AIO_SWITCH_LOCK = threading.Lock()
142
 
143
 
144
- def _hard_cuda_cleanup():
145
- gc.collect()
146
- if torch.cuda.is_available():
147
- try:
148
- torch.cuda.synchronize()
149
- except Exception:
150
- pass
151
- torch.cuda.empty_cache()
152
- try:
153
- torch.cuda.ipc_collect()
154
- except Exception:
155
- pass
156
 
157
 
158
- def _get_cache_lock(version: str) -> threading.Lock:
159
- with _CACHE_LOCKS_GUARD:
160
- if version not in _CACHE_LOCKS:
161
- _CACHE_LOCKS[version] = threading.Lock()
162
- return _CACHE_LOCKS[version]
163
 
164
 
165
  def discover_aio_versions(repo_id: str) -> list[str]:
@@ -188,14 +183,50 @@ def discover_aio_versions(repo_id: str) -> list[str]:
188
 
189
 
190
  AVAILABLE_AIO_VERSIONS = discover_aio_versions(AIO_REPO_ID)
191
- if DEFAULT_AIO_VERSION not in AVAILABLE_AIO_VERSIONS and AVAILABLE_AIO_VERSIONS:
192
- DEFAULT_AIO_VERSION = AVAILABLE_AIO_VERSIONS[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
 
195
  def ensure_aio_cached(version: str) -> None:
196
  """
197
  CPU-only: download all files under vXX/transformer/ into HF cache.
198
-
199
  Idempotent + locked per version to avoid duplicate concurrent downloads.
200
  """
201
  version = version or DEFAULT_AIO_VERSION
@@ -214,7 +245,6 @@ def ensure_aio_cached(version: str) -> None:
214
  if not needed:
215
  raise gr.Error(f"No files found under {sub}/ in {AIO_REPO_ID}")
216
 
217
- # Download into cache (fast/no-op if already present)
218
  for f in needed:
219
  hf_hub_download(repo_id=AIO_REPO_ID, filename=f, repo_type="model")
220
 
@@ -228,28 +258,21 @@ def ensure_aio_cached_ui(version: str):
228
  try:
229
  version = version or DEFAULT_AIO_VERSION
230
  if version in _CACHED_AIO_VERSIONS:
231
- return gr.update(value=f"✅ Cached **{version}** (on disk)")
232
 
233
  print(f"⬇️ Caching AIO version on CPU: {version}")
234
  ensure_aio_cached(version)
235
- return gr.update(value=f"✅ Cached **{version}** (on disk)")
236
  except Exception as e:
237
  print("❌ Cache step failed:\n", traceback.format_exc())
238
  raise gr.Error(f"Cache failed for {version}: {e}")
239
 
240
 
241
  def refresh_aio_versions_ui(current_value: str):
242
- global AVAILABLE_AIO_VERSIONS, DEFAULT_AIO_VERSION
243
  AVAILABLE_AIO_VERSIONS = discover_aio_versions(AIO_REPO_ID)
244
 
245
- if DEFAULT_AIO_VERSION not in AVAILABLE_AIO_VERSIONS and AVAILABLE_AIO_VERSIONS:
246
- DEFAULT_AIO_VERSION = AVAILABLE_AIO_VERSIONS[0]
247
-
248
- if current_value in AVAILABLE_AIO_VERSIONS:
249
- new_value = current_value
250
- else:
251
- new_value = DEFAULT_AIO_VERSION
252
-
253
  status = f"Found {len(AVAILABLE_AIO_VERSIONS)} version(s): {', '.join(AVAILABLE_AIO_VERSIONS)}"
254
  return gr.update(choices=AVAILABLE_AIO_VERSIONS, value=new_value), gr.update(value=status)
255
 
@@ -262,24 +285,45 @@ def _apply_fa3_if_possible():
262
  print(f"Warning: Could not set FA3 processor: {e}")
263
 
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  # ============================================================
266
- # Build pipeline once (default version at startup)
267
- # IMPORTANT: do NOT use device_map="cuda" here.
268
- # Load transformer on CPU, then move whole pipeline to GPU normally.
269
  # ============================================================
270
 
271
- print(f"📦 Loading AIO transformer: {AIO_REPO_ID} / {DEFAULT_AIO_VERSION}/transformer (CPU -> then move to cuda)")
272
-
273
- _default_transformer = QwenImageTransformer2DModel.from_pretrained(
274
- AIO_REPO_ID,
275
- subfolder=f"{DEFAULT_AIO_VERSION}/transformer",
276
- torch_dtype=dtype,
277
- # NO device_map here
278
- )
279
 
280
  pipe = QwenImageEditPlusPipeline.from_pretrained(
281
  "Qwen/Qwen-Image-Edit-2511",
282
- transformer=_default_transformer,
 
 
 
 
 
283
  torch_dtype=dtype,
284
  ).to(device)
285
 
@@ -636,40 +680,10 @@ def _unload_all_loras():
636
 
637
 
638
  # ============================================================
639
- # AIO switch (GPU, local cache only) — SAFER SWAP
640
  # ============================================================
641
 
642
 
643
- def _remove_transformer_from_pipe():
644
- """Best-effort remove transformer references so GC can free VRAM."""
645
- try:
646
- if hasattr(pipe, "_modules") and "transformer" in pipe._modules:
647
- pipe._modules.pop("transformer", None)
648
- except Exception:
649
- pass
650
- try:
651
- pipe.transformer = None
652
- except Exception:
653
- pass
654
-
655
-
656
- def _register_transformer_in_pipe(new_t):
657
- """Register transformer the diffusers way when possible."""
658
- try:
659
- if hasattr(pipe, "register_modules"):
660
- pipe.register_modules(transformer=new_t)
661
- return
662
- except Exception:
663
- pass
664
- # fallback
665
- pipe.transformer = new_t
666
- try:
667
- if hasattr(pipe, "_modules"):
668
- pipe._modules["transformer"] = new_t
669
- except Exception:
670
- pass
671
-
672
-
673
  def _switch_aio_version_local_only(target_version: str, current_loaded: str) -> str:
674
  """
675
  Must be called while already inside a GPU task.
@@ -686,30 +700,42 @@ def _switch_aio_version_local_only(target_version: str, current_loaded: str) ->
686
 
687
  print(f"🔁 Switching AIO transformer to: {AIO_REPO_ID} / {target_version}/transformer (local-only)")
688
 
689
- # Detach LoRAs first (keeps swap cleaner)
690
  _unload_all_loras()
691
 
692
- # Drop old transformer refs and free VRAM WITHOUT copying back to CPU
693
  old_t = getattr(pipe, "transformer", None)
694
- _remove_transformer_from_pipe()
 
 
 
 
 
 
 
 
 
 
 
 
695
  if old_t is not None:
696
  try:
697
- del old_t
698
  except Exception:
699
  pass
 
700
 
701
  _hard_cuda_cleanup()
702
 
703
- # Load new transformer on CPU (from local cache), then move to GPU normally
704
  new_t = QwenImageTransformer2DModel.from_pretrained(
705
  AIO_REPO_ID,
706
  subfolder=f"{target_version}/transformer",
707
  torch_dtype=dtype,
708
  local_files_only=True,
709
- )
710
- new_t = new_t.to(device)
711
 
712
- _register_transformer_in_pipe(new_t)
 
 
 
713
 
714
  _apply_fa3_if_possible()
715
  _hard_cuda_cleanup()
@@ -810,17 +836,16 @@ def infer(
810
  target_long_edge = get_target_long_edge_for_lora(lora_adapter)
811
  width, height = compute_dimensions(img1, target_long_edge)
812
 
813
- with torch.inference_mode():
814
- result = pipe(
815
- image=pipe_images,
816
- prompt=prompt,
817
- negative_prompt=negative_prompt,
818
- height=height,
819
- width=width,
820
- num_inference_steps=steps,
821
- generator=generator,
822
- true_cfg_scale=guidance_scale,
823
- ).images[0]
824
 
825
  status = f"✅ Loaded: **{new_loaded}** | Selected: **{aio_version}**"
826
  return result, seed, new_loaded, gr.update(value=status)
@@ -838,7 +863,6 @@ def infer_example(input_image, prompt, lora_adapter, loaded_version_state):
838
  input_pil = input_image.convert("RGB")
839
  guidance_scale = 1.0
840
  steps = 4
841
- # Examples: run with currently loaded transformer (no switch)
842
  result, seed, new_loaded, _ = infer(
843
  loaded_version_state,
844
  loaded_version_state,
@@ -886,6 +910,7 @@ with gr.Blocks() as demo:
886
  interactive=True,
887
  )
888
  refresh_versions = gr.Button("Refresh", variant="secondary")
 
889
 
890
  aio_status = gr.Markdown(
891
  f"✅ Loaded: **{DEFAULT_AIO_VERSION}** | Found {len(AVAILABLE_AIO_VERSIONS)} version(s): {', '.join(AVAILABLE_AIO_VERSIONS)}"
@@ -949,6 +974,13 @@ with gr.Blocks() as demo:
949
  outputs=[aio_version, aio_status],
950
  )
951
 
 
 
 
 
 
 
 
952
  gr.Examples(
953
  examples=[
954
  ["examples/1.jpg", "Transform into anime.", "Photo-to-Anime"],
 
1
  import os
2
  import re
3
  import gc
4
+ import sys
5
+ import time
6
  import random
7
  import threading
8
  import traceback
 
116
  print("Using device:", device)
117
 
118
  # ============================================================
119
+ # Pipeline imports (keep your existing transformer class)
120
  # ============================================================
121
 
122
  from diffusers import FlowMatchEulerDiscreteScheduler # noqa: F401
 
127
  dtype = torch.bfloat16
128
 
129
  # ============================================================
130
+ # AIO versioning + "boot version" persistence
131
  # ============================================================
132
 
133
  AIO_REPO_ID = "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO"
134
  DEFAULT_AIO_VERSION = "v19"
135
  _AIO_VER_RE = re.compile(r"^(v\d+)$")
136
 
137
+ # Preferred boot version sources:
138
+ # 1) Space Variable: DEFAULT_AIO_VERSION
139
+ # 2) Local preference file in HF cache dir (best-effort; not guaranteed across cold rebuilds)
140
+ _PREF_PATH = os.path.join(os.path.expanduser("~"), ".cache", "aio_default_version.txt")
 
 
 
141
 
142
 
143
+ def _read_pref_file() -> Optional[str]:
144
+ try:
145
+ if os.path.isfile(_PREF_PATH):
146
+ with open(_PREF_PATH, "r", encoding="utf-8") as f:
147
+ v = f.read().strip()
148
+ return v or None
149
+ except Exception:
150
+ return None
151
+ return None
 
 
 
152
 
153
 
154
+ def _write_pref_file(v: str) -> None:
155
+ os.makedirs(os.path.dirname(_PREF_PATH), exist_ok=True)
156
+ with open(_PREF_PATH, "w", encoding="utf-8") as f:
157
+ f.write(v)
 
158
 
159
 
160
  def discover_aio_versions(repo_id: str) -> list[str]:
 
183
 
184
 
185
  AVAILABLE_AIO_VERSIONS = discover_aio_versions(AIO_REPO_ID)
186
+
187
+ # pick boot version (env > file > fallback)
188
+ _env_boot = (os.environ.get("DEFAULT_AIO_VERSION") or "").strip()
189
+ _file_boot = (_read_pref_file() or "").strip()
190
+ BOOT_AIO_VERSION = _env_boot or _file_boot or DEFAULT_AIO_VERSION
191
+
192
+ if BOOT_AIO_VERSION not in AVAILABLE_AIO_VERSIONS and AVAILABLE_AIO_VERSIONS:
193
+ BOOT_AIO_VERSION = AVAILABLE_AIO_VERSIONS[0]
194
+
195
+ DEFAULT_AIO_VERSION = BOOT_AIO_VERSION # use boot version as the UI + pipeline default
196
+
197
+ # Cache control (prevents double-download when dropdown+run are both triggered)
198
+ _CACHED_AIO_VERSIONS: set[str] = set()
199
+ _CACHE_LOCKS: dict[str, threading.Lock] = {}
200
+ _CACHE_LOCKS_GUARD = threading.Lock()
201
+
202
+ # GPU switch lock (prevents concurrent swaps)
203
+ _AIO_SWITCH_LOCK = threading.Lock()
204
+
205
+
206
+ def _hard_cuda_cleanup():
207
+ gc.collect()
208
+ if torch.cuda.is_available():
209
+ try:
210
+ torch.cuda.synchronize()
211
+ except Exception:
212
+ pass
213
+ torch.cuda.empty_cache()
214
+ try:
215
+ torch.cuda.ipc_collect()
216
+ except Exception:
217
+ pass
218
+
219
+
220
+ def _get_cache_lock(version: str) -> threading.Lock:
221
+ with _CACHE_LOCKS_GUARD:
222
+ if version not in _CACHE_LOCKS:
223
+ _CACHE_LOCKS[version] = threading.Lock()
224
+ return _CACHE_LOCKS[version]
225
 
226
 
227
  def ensure_aio_cached(version: str) -> None:
228
  """
229
  CPU-only: download all files under vXX/transformer/ into HF cache.
 
230
  Idempotent + locked per version to avoid duplicate concurrent downloads.
231
  """
232
  version = version or DEFAULT_AIO_VERSION
 
245
  if not needed:
246
  raise gr.Error(f"No files found under {sub}/ in {AIO_REPO_ID}")
247
 
 
248
  for f in needed:
249
  hf_hub_download(repo_id=AIO_REPO_ID, filename=f, repo_type="model")
250
 
 
258
  try:
259
  version = version or DEFAULT_AIO_VERSION
260
  if version in _CACHED_AIO_VERSIONS:
261
+ return gr.update(value=f"✅ Cached {version} (ready)")
262
 
263
  print(f"⬇️ Caching AIO version on CPU: {version}")
264
  ensure_aio_cached(version)
265
+ return gr.update(value=f"✅ Cached {version} (ready)")
266
  except Exception as e:
267
  print("❌ Cache step failed:\n", traceback.format_exc())
268
  raise gr.Error(f"Cache failed for {version}: {e}")
269
 
270
 
271
  def refresh_aio_versions_ui(current_value: str):
272
+ global AVAILABLE_AIO_VERSIONS
273
  AVAILABLE_AIO_VERSIONS = discover_aio_versions(AIO_REPO_ID)
274
 
275
+ new_value = current_value if current_value in AVAILABLE_AIO_VERSIONS else (AVAILABLE_AIO_VERSIONS[0] if AVAILABLE_AIO_VERSIONS else DEFAULT_AIO_VERSION)
 
 
 
 
 
 
 
276
  status = f"Found {len(AVAILABLE_AIO_VERSIONS)} version(s): {', '.join(AVAILABLE_AIO_VERSIONS)}"
277
  return gr.update(choices=AVAILABLE_AIO_VERSIONS, value=new_value), gr.update(value=status)
278
 
 
285
  print(f"Warning: Could not set FA3 processor: {e}")
286
 
287
 
288
+ def set_default_and_restart_ui(version: str):
289
+ """
290
+ Best-effort: store desired boot version and force a restart so it loads at startup
291
+ (avoids transformer swapping during inference when possible).
292
+ """
293
+ version = version or DEFAULT_AIO_VERSION
294
+ if version not in AVAILABLE_AIO_VERSIONS:
295
+ raise gr.Error(f"Unknown version: {version}")
296
+
297
+ try:
298
+ _write_pref_file(version)
299
+ except Exception as e:
300
+ print(f"⚠️ Could not write preference file: {e}")
301
+
302
+ # Trigger restart a moment after returning UI update
303
+ def _restart_soon():
304
+ time.sleep(1.0)
305
+ # Let the supervisor restart the process
306
+ os._exit(0)
307
+
308
+ threading.Thread(target=_restart_soon, daemon=True).start()
309
+ return gr.update(value=f"✅ Saved startup version: **{version}**. Restarting Space now…")
310
+
311
+
312
  # ============================================================
313
+ # Build pipeline once (boot version at startup)
 
 
314
  # ============================================================
315
 
316
+ print(f"📦 Boot AIO version: {DEFAULT_AIO_VERSION} (env={_env_boot or '—'}, file={_file_boot or '—'})")
317
+ print(f"📦 Loading AIO transformer: {AIO_REPO_ID} / {DEFAULT_AIO_VERSION}/transformer (startup)")
 
 
 
 
 
 
318
 
319
  pipe = QwenImageEditPlusPipeline.from_pretrained(
320
  "Qwen/Qwen-Image-Edit-2511",
321
+ transformer=QwenImageTransformer2DModel.from_pretrained(
322
+ AIO_REPO_ID,
323
+ subfolder=f"{DEFAULT_AIO_VERSION}/transformer",
324
+ torch_dtype=dtype,
325
+ device_map="cuda", # keep your existing setup
326
+ ),
327
  torch_dtype=dtype,
328
  ).to(device)
329
 
 
680
 
681
 
682
  # ============================================================
683
+ # AIO switch (GPU, local cache only)
684
  # ============================================================
685
 
686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687
  def _switch_aio_version_local_only(target_version: str, current_loaded: str) -> str:
688
  """
689
  Must be called while already inside a GPU task.
 
700
 
701
  print(f"🔁 Switching AIO transformer to: {AIO_REPO_ID} / {target_version}/transformer (local-only)")
702
 
 
703
  _unload_all_loras()
704
 
 
705
  old_t = getattr(pipe, "transformer", None)
706
+
707
+ # Drop module registry refs so old transformer can be freed
708
+ try:
709
+ if hasattr(pipe, "_modules") and "transformer" in pipe._modules:
710
+ pipe._modules.pop("transformer", None)
711
+ except Exception:
712
+ pass
713
+
714
+ try:
715
+ pipe.transformer = None
716
+ except Exception:
717
+ pass
718
+
719
  if old_t is not None:
720
  try:
721
+ old_t.to("cpu")
722
  except Exception:
723
  pass
724
+ del old_t
725
 
726
  _hard_cuda_cleanup()
727
 
 
728
  new_t = QwenImageTransformer2DModel.from_pretrained(
729
  AIO_REPO_ID,
730
  subfolder=f"{target_version}/transformer",
731
  torch_dtype=dtype,
732
  local_files_only=True,
733
+ ).to(device)
 
734
 
735
+ try:
736
+ pipe.add_module("transformer", new_t)
737
+ except Exception:
738
+ pipe.transformer = new_t
739
 
740
  _apply_fa3_if_possible()
741
  _hard_cuda_cleanup()
 
836
  target_long_edge = get_target_long_edge_for_lora(lora_adapter)
837
  width, height = compute_dimensions(img1, target_long_edge)
838
 
839
+ result = pipe(
840
+ image=pipe_images,
841
+ prompt=prompt,
842
+ negative_prompt=negative_prompt,
843
+ height=height,
844
+ width=width,
845
+ num_inference_steps=steps,
846
+ generator=generator,
847
+ true_cfg_scale=guidance_scale,
848
+ ).images[0]
 
849
 
850
  status = f"✅ Loaded: **{new_loaded}** | Selected: **{aio_version}**"
851
  return result, seed, new_loaded, gr.update(value=status)
 
863
  input_pil = input_image.convert("RGB")
864
  guidance_scale = 1.0
865
  steps = 4
 
866
  result, seed, new_loaded, _ = infer(
867
  loaded_version_state,
868
  loaded_version_state,
 
910
  interactive=True,
911
  )
912
  refresh_versions = gr.Button("Refresh", variant="secondary")
913
+ set_default_restart = gr.Button("Set as startup version & restart", variant="secondary")
914
 
915
  aio_status = gr.Markdown(
916
  f"✅ Loaded: **{DEFAULT_AIO_VERSION}** | Found {len(AVAILABLE_AIO_VERSIONS)} version(s): {', '.join(AVAILABLE_AIO_VERSIONS)}"
 
974
  outputs=[aio_version, aio_status],
975
  )
976
 
977
+ # Save boot version + restart
978
+ set_default_restart.click(
979
+ fn=set_default_and_restart_ui,
980
+ inputs=[aio_version],
981
+ outputs=[aio_status],
982
+ )
983
+
984
  gr.Examples(
985
  examples=[
986
  ["examples/1.jpg", "Transform into anime.", "Photo-to-Anime"],