Spaces:
Running on Zero
Running on Zero
Professional Noob commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import gc
|
| 4 |
-
import threading
|
| 5 |
import random
|
|
|
|
|
|
|
| 6 |
from typing import Iterable, Optional
|
| 7 |
|
| 8 |
import gradio as gr
|
|
@@ -112,7 +113,7 @@ if torch.cuda.is_available():
|
|
| 112 |
print("Using device:", device)
|
| 113 |
|
| 114 |
# ============================================================
|
| 115 |
-
# Pipeline
|
| 116 |
# ============================================================
|
| 117 |
|
| 118 |
from diffusers import FlowMatchEulerDiscreteScheduler # noqa: F401
|
|
@@ -122,60 +123,97 @@ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
|
|
| 122 |
|
| 123 |
dtype = torch.bfloat16
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
AIO_REPO_ID = "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO"
|
| 126 |
DEFAULT_AIO_VERSION = "v19"
|
| 127 |
-
_AIO_VERSION_RE = re.compile(r"^(v\d+)/transformer/")
|
| 128 |
|
| 129 |
-
|
|
|
|
|
|
|
| 130 |
CURRENT_AIO_VERSION = DEFAULT_AIO_VERSION
|
| 131 |
|
| 132 |
|
| 133 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
"""
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
v21/transformer/...
|
| 138 |
-
Returns sorted: v1, v2, v10, ...
|
| 139 |
"""
|
| 140 |
try:
|
| 141 |
api = HfApi()
|
| 142 |
-
files = api.list_repo_files(repo_id=repo_id)
|
| 143 |
versions = set()
|
| 144 |
for f in files:
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
| 148 |
if not versions:
|
| 149 |
return [DEFAULT_AIO_VERSION]
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
return out
|
| 153 |
except Exception as e:
|
| 154 |
-
print(f"⚠️
|
| 155 |
return [DEFAULT_AIO_VERSION]
|
| 156 |
|
| 157 |
|
| 158 |
-
|
| 159 |
-
if DEFAULT_AIO_VERSION not in
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
CURRENT_AIO_VERSION = DEFAULT_AIO_VERSION
|
| 163 |
|
| 164 |
|
| 165 |
-
def
|
| 166 |
"""
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
| 170 |
"""
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
|
| 181 |
def _apply_fa3_if_possible():
|
|
@@ -186,33 +224,18 @@ def _apply_fa3_if_possible():
|
|
| 186 |
print(f"Warning: Could not set FA3 processor: {e}")
|
| 187 |
|
| 188 |
|
| 189 |
-
|
| 190 |
-
gc.collect()
|
| 191 |
-
if torch.cuda.is_available():
|
| 192 |
-
try:
|
| 193 |
-
torch.cuda.synchronize()
|
| 194 |
-
except Exception:
|
| 195 |
-
pass
|
| 196 |
-
torch.cuda.empty_cache()
|
| 197 |
-
try:
|
| 198 |
-
torch.cuda.ipc_collect()
|
| 199 |
-
except Exception:
|
| 200 |
-
pass
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
# Build pipeline once. We only swap pipe.transformer at runtime.
|
| 204 |
pipe = QwenImageEditPlusPipeline.from_pretrained(
|
| 205 |
"Qwen/Qwen-Image-Edit-2511",
|
| 206 |
-
transformer=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
torch_dtype=dtype,
|
| 208 |
).to(device)
|
| 209 |
|
| 210 |
-
# move transformer to device (pipeline .to() might not fully move nested module in some custom pipelines)
|
| 211 |
-
try:
|
| 212 |
-
pipe.transformer.to(device)
|
| 213 |
-
except Exception:
|
| 214 |
-
pass
|
| 215 |
-
|
| 216 |
_apply_fa3_if_possible()
|
| 217 |
|
| 218 |
MAX_SEED = np.iinfo(np.int32).max
|
|
@@ -566,12 +589,11 @@ def _ensure_loaded_and_get_active_adapters(selected_lora: str):
|
|
| 566 |
|
| 567 |
|
| 568 |
# ============================================================
|
| 569 |
-
# AIO
|
| 570 |
# ============================================================
|
| 571 |
|
| 572 |
|
| 573 |
def _unload_all_loras():
|
| 574 |
-
# When swapping transformer versions, previously loaded LoRAs are no longer safe to keep around.
|
| 575 |
global LOADED_ADAPTERS
|
| 576 |
try:
|
| 577 |
pipe.set_adapters([], adapter_weights=[])
|
|
@@ -584,39 +606,39 @@ def _unload_all_loras():
|
|
| 584 |
LOADED_ADAPTERS.clear()
|
| 585 |
|
| 586 |
|
| 587 |
-
def
|
| 588 |
"""
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
Key safeguards vs your crash:
|
| 592 |
-
- detach & move old transformer off GPU before loading a new one
|
| 593 |
-
- load new transformer without device_map (CPU), then move to GPU
|
| 594 |
-
- clear LoRAs because module graph changes across transformers
|
| 595 |
"""
|
| 596 |
global CURRENT_AIO_VERSION
|
| 597 |
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
if target_version == CURRENT_AIO_VERSION:
|
| 601 |
return
|
| 602 |
|
| 603 |
-
with
|
| 604 |
-
if
|
| 605 |
return
|
| 606 |
|
| 607 |
-
print(f"🔁 Switching AIO transformer to: {AIO_REPO_ID} / {
|
| 608 |
|
| 609 |
-
#
|
| 610 |
_unload_all_loras()
|
| 611 |
|
| 612 |
-
# Detach old transformer
|
| 613 |
old_t = getattr(pipe, "transformer", None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
try:
|
| 615 |
-
|
| 616 |
-
pipe.register_modules(transformer=torch.nn.Identity())
|
| 617 |
except Exception:
|
| 618 |
-
|
| 619 |
-
pipe.transformer = torch.nn.Identity()
|
| 620 |
|
| 621 |
if old_t is not None:
|
| 622 |
try:
|
|
@@ -627,25 +649,22 @@ def switch_aio_version(target_version: str):
|
|
| 627 |
|
| 628 |
_hard_cuda_cleanup()
|
| 629 |
|
| 630 |
-
# Load
|
| 631 |
-
new_t =
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
raise gr.Error(f"Failed to move transformer {target_version} to {device}: {e}")
|
| 639 |
-
|
| 640 |
-
# Swap in
|
| 641 |
try:
|
| 642 |
-
pipe.
|
| 643 |
except Exception:
|
| 644 |
pipe.transformer = new_t
|
| 645 |
|
| 646 |
_apply_fa3_if_possible()
|
| 647 |
|
| 648 |
-
CURRENT_AIO_VERSION =
|
| 649 |
_hard_cuda_cleanup()
|
| 650 |
|
| 651 |
|
|
@@ -655,6 +674,7 @@ def switch_aio_version(target_version: str):
|
|
| 655 |
|
| 656 |
|
| 657 |
def on_lora_change_ui(selected_lora, current_prompt):
|
|
|
|
| 658 |
if selected_lora != NONE_LORA:
|
| 659 |
preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
|
| 660 |
if preset and (current_prompt is None or str(current_prompt).strip() == ""):
|
|
@@ -664,6 +684,7 @@ def on_lora_change_ui(selected_lora, current_prompt):
|
|
| 664 |
else:
|
| 665 |
prompt_update = gr.update(value=current_prompt)
|
| 666 |
|
|
|
|
| 667 |
if lora_requires_two_images(selected_lora):
|
| 668 |
img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
|
| 669 |
else:
|
|
@@ -673,14 +694,19 @@ def on_lora_change_ui(selected_lora, current_prompt):
|
|
| 673 |
|
| 674 |
|
| 675 |
def refresh_aio_versions_ui(current_value: str):
|
| 676 |
-
global
|
| 677 |
-
|
| 678 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
new_value = current_value
|
| 680 |
else:
|
| 681 |
-
new_value = DEFAULT_AIO_VERSION
|
| 682 |
|
| 683 |
-
|
|
|
|
| 684 |
|
| 685 |
|
| 686 |
# ============================================================
|
|
@@ -702,58 +728,62 @@ def infer(
|
|
| 702 |
steps,
|
| 703 |
progress=gr.Progress(track_tqdm=True),
|
| 704 |
):
|
| 705 |
-
|
|
|
|
| 706 |
|
| 707 |
-
|
| 708 |
-
|
| 709 |
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
|
| 725 |
-
|
| 726 |
-
|
| 727 |
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
|
| 734 |
-
|
| 735 |
-
|
| 736 |
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
|
|
|
| 743 |
|
| 744 |
-
|
| 745 |
-
|
|
|
|
| 746 |
|
| 747 |
-
|
|
|
|
| 748 |
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
|
| 753 |
-
|
| 754 |
-
|
|
|
|
| 755 |
|
| 756 |
-
try:
|
| 757 |
result = pipe(
|
| 758 |
image=pipe_images,
|
| 759 |
prompt=prompt,
|
|
@@ -764,7 +794,11 @@ def infer(
|
|
| 764 |
generator=generator,
|
| 765 |
true_cfg_scale=guidance_scale,
|
| 766 |
).images[0]
|
|
|
|
| 767 |
return result, seed
|
|
|
|
|
|
|
|
|
|
| 768 |
finally:
|
| 769 |
_hard_cuda_cleanup()
|
| 770 |
|
|
@@ -776,7 +810,7 @@ def infer_example(input_image, prompt, lora_adapter):
|
|
| 776 |
input_pil = input_image.convert("RGB")
|
| 777 |
guidance_scale = 1.0
|
| 778 |
steps = 4
|
| 779 |
-
# Examples
|
| 780 |
result, seed = infer(CURRENT_AIO_VERSION, input_pil, None, None, prompt, lora_adapter, 0, True, guidance_scale, steps)
|
| 781 |
return result, seed
|
| 782 |
|
|
@@ -805,11 +839,24 @@ with gr.Blocks() as demo:
|
|
| 805 |
with gr.Row():
|
| 806 |
aio_version = gr.Dropdown(
|
| 807 |
label="Phr00t Rapid AIO Version",
|
| 808 |
-
choices=
|
| 809 |
-
value=
|
|
|
|
| 810 |
)
|
| 811 |
refresh_versions = gr.Button("Refresh", variant="secondary")
|
| 812 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
|
| 814 |
refresh_versions.click(
|
| 815 |
fn=refresh_aio_versions_ui,
|
|
@@ -897,7 +944,14 @@ with gr.Blocks() as demo:
|
|
| 897 |
label="Examples",
|
| 898 |
)
|
| 899 |
|
|
|
|
|
|
|
|
|
|
| 900 |
run_button.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 901 |
fn=infer,
|
| 902 |
inputs=[
|
| 903 |
aio_version,
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import gc
|
|
|
|
| 4 |
import random
|
| 5 |
+
import threading
|
| 6 |
+
import traceback
|
| 7 |
from typing import Iterable, Optional
|
| 8 |
|
| 9 |
import gradio as gr
|
|
|
|
| 113 |
print("Using device:", device)
|
| 114 |
|
| 115 |
# ============================================================
|
| 116 |
+
# Pipeline
|
| 117 |
# ============================================================
|
| 118 |
|
| 119 |
from diffusers import FlowMatchEulerDiscreteScheduler # noqa: F401
|
|
|
|
| 123 |
|
| 124 |
dtype = torch.bfloat16
|
| 125 |
|
| 126 |
+
# ============================================================
|
| 127 |
+
# AIO version discovery + caching (CPU) + switching (GPU)
|
| 128 |
+
# ============================================================
|
| 129 |
+
|
| 130 |
AIO_REPO_ID = "Pr0f3ssi0n4ln00b/Phr00t-Qwen-Rapid-AIO"
|
| 131 |
DEFAULT_AIO_VERSION = "v19"
|
|
|
|
| 132 |
|
| 133 |
+
_AIO_VER_RE = re.compile(r"^(v\d+)$")
|
| 134 |
+
_AIO_SWITCH_LOCK = threading.Lock()
|
| 135 |
+
|
| 136 |
CURRENT_AIO_VERSION = DEFAULT_AIO_VERSION
|
| 137 |
|
| 138 |
|
| 139 |
+
def _hard_cuda_cleanup():
|
| 140 |
+
gc.collect()
|
| 141 |
+
if torch.cuda.is_available():
|
| 142 |
+
try:
|
| 143 |
+
torch.cuda.synchronize()
|
| 144 |
+
except Exception:
|
| 145 |
+
pass
|
| 146 |
+
torch.cuda.empty_cache()
|
| 147 |
+
try:
|
| 148 |
+
torch.cuda.ipc_collect()
|
| 149 |
+
except Exception:
|
| 150 |
+
pass
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def discover_aio_versions(repo_id: str) -> list[str]:
|
| 154 |
"""
|
| 155 |
+
Finds versions by scanning repo file paths with the naming convention:
|
| 156 |
+
vNN/transformer/...
|
|
|
|
|
|
|
| 157 |
"""
|
| 158 |
try:
|
| 159 |
api = HfApi()
|
| 160 |
+
files = api.list_repo_files(repo_id=repo_id, repo_type="model")
|
| 161 |
versions = set()
|
| 162 |
for f in files:
|
| 163 |
+
if "/transformer/" not in f:
|
| 164 |
+
continue
|
| 165 |
+
head = f.split("/transformer/", 1)[0]
|
| 166 |
+
if _AIO_VER_RE.fullmatch(head):
|
| 167 |
+
versions.add(head)
|
| 168 |
+
|
| 169 |
if not versions:
|
| 170 |
return [DEFAULT_AIO_VERSION]
|
| 171 |
+
|
| 172 |
+
return sorted(versions, key=lambda s: int(s[1:]))
|
|
|
|
| 173 |
except Exception as e:
|
| 174 |
+
print(f"⚠️ AIO version discovery failed: {e}")
|
| 175 |
return [DEFAULT_AIO_VERSION]
|
| 176 |
|
| 177 |
|
| 178 |
+
AVAILABLE_AIO_VERSIONS = discover_aio_versions(AIO_REPO_ID)
|
| 179 |
+
if DEFAULT_AIO_VERSION not in AVAILABLE_AIO_VERSIONS and AVAILABLE_AIO_VERSIONS:
|
| 180 |
+
DEFAULT_AIO_VERSION = AVAILABLE_AIO_VERSIONS[0]
|
| 181 |
+
CURRENT_AIO_VERSION = DEFAULT_AIO_VERSION
|
|
|
|
| 182 |
|
| 183 |
|
| 184 |
+
def ensure_aio_cached(version: str) -> None:
|
| 185 |
"""
|
| 186 |
+
CPU-only: download all files under vXX/transformer/ into HF cache.
|
| 187 |
+
|
| 188 |
+
This prevents long Hub downloads during GPU tasks (which often causes
|
| 189 |
+
"GPU task aborted" on ZeroGPU).
|
| 190 |
"""
|
| 191 |
+
version = version or DEFAULT_AIO_VERSION
|
| 192 |
+
sub = f"{version}/transformer"
|
| 193 |
+
api = HfApi()
|
| 194 |
+
|
| 195 |
+
files = api.list_repo_files(repo_id=AIO_REPO_ID, repo_type="model")
|
| 196 |
+
needed = [f for f in files if f.startswith(sub + "/")]
|
| 197 |
+
if not needed:
|
| 198 |
+
raise gr.Error(f"No files found under {sub}/ in {AIO_REPO_ID}")
|
| 199 |
+
|
| 200 |
+
for f in needed:
|
| 201 |
+
hf_hub_download(repo_id=AIO_REPO_ID, filename=f, repo_type="model")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def ensure_aio_cached_ui(version: str):
|
| 205 |
+
"""
|
| 206 |
+
Gradio handler (CPU): cache selected version.
|
| 207 |
+
Returns status markdown + keeps run button interactive.
|
| 208 |
+
"""
|
| 209 |
+
try:
|
| 210 |
+
version = version or DEFAULT_AIO_VERSION
|
| 211 |
+
print(f"⬇️ Caching AIO version on CPU: {version}")
|
| 212 |
+
ensure_aio_cached(version)
|
| 213 |
+
return gr.update(value=f"✅ Cached {version} (ready)"), gr.update(interactive=True)
|
| 214 |
+
except Exception as e:
|
| 215 |
+
print("❌ Cache step failed:\n", traceback.format_exc())
|
| 216 |
+
raise gr.Error(f"Cache failed for {version}: {e}")
|
| 217 |
|
| 218 |
|
| 219 |
def _apply_fa3_if_possible():
|
|
|
|
| 224 |
print(f"Warning: Could not set FA3 processor: {e}")
|
| 225 |
|
| 226 |
|
| 227 |
+
# Build pipeline once (default version at startup)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
pipe = QwenImageEditPlusPipeline.from_pretrained(
|
| 229 |
"Qwen/Qwen-Image-Edit-2511",
|
| 230 |
+
transformer=QwenImageTransformer2DModel.from_pretrained(
|
| 231 |
+
AIO_REPO_ID,
|
| 232 |
+
subfolder=f"{DEFAULT_AIO_VERSION}/transformer",
|
| 233 |
+
torch_dtype=dtype,
|
| 234 |
+
device_map="cuda", # keep your existing setup
|
| 235 |
+
),
|
| 236 |
torch_dtype=dtype,
|
| 237 |
).to(device)
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
_apply_fa3_if_possible()
|
| 240 |
|
| 241 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
| 589 |
|
| 590 |
|
| 591 |
# ============================================================
|
| 592 |
+
# AIO switch (GPU, local cache only; called inside infer)
|
| 593 |
# ============================================================
|
| 594 |
|
| 595 |
|
| 596 |
def _unload_all_loras():
|
|
|
|
| 597 |
global LOADED_ADAPTERS
|
| 598 |
try:
|
| 599 |
pipe.set_adapters([], adapter_weights=[])
|
|
|
|
| 606 |
LOADED_ADAPTERS.clear()
|
| 607 |
|
| 608 |
|
| 609 |
+
def _switch_aio_version_local_only(version: str):
|
| 610 |
"""
|
| 611 |
+
Must be called while already inside a GPU task.
|
| 612 |
+
Uses local_files_only=True (assumes ensure_aio_cached ran on CPU first).
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
"""
|
| 614 |
global CURRENT_AIO_VERSION
|
| 615 |
|
| 616 |
+
version = version or DEFAULT_AIO_VERSION
|
| 617 |
+
if version == CURRENT_AIO_VERSION:
|
|
|
|
| 618 |
return
|
| 619 |
|
| 620 |
+
with _AIO_SWITCH_LOCK:
|
| 621 |
+
if version == CURRENT_AIO_VERSION:
|
| 622 |
return
|
| 623 |
|
| 624 |
+
print(f"🔁 Switching AIO transformer to: {AIO_REPO_ID} / {version}/transformer (local-only)")
|
| 625 |
|
| 626 |
+
# LoRAs are transformer-graph dependent
|
| 627 |
_unload_all_loras()
|
| 628 |
|
| 629 |
+
# Detach old transformer strongly
|
| 630 |
old_t = getattr(pipe, "transformer", None)
|
| 631 |
+
|
| 632 |
+
try:
|
| 633 |
+
if hasattr(pipe, "_modules") and "transformer" in pipe._modules:
|
| 634 |
+
pipe._modules.pop("transformer", None)
|
| 635 |
+
except Exception:
|
| 636 |
+
pass
|
| 637 |
+
|
| 638 |
try:
|
| 639 |
+
pipe.transformer = None
|
|
|
|
| 640 |
except Exception:
|
| 641 |
+
pass
|
|
|
|
| 642 |
|
| 643 |
if old_t is not None:
|
| 644 |
try:
|
|
|
|
| 649 |
|
| 650 |
_hard_cuda_cleanup()
|
| 651 |
|
| 652 |
+
# Load from local cache only (no downloads inside GPU task)
|
| 653 |
+
new_t = QwenImageTransformer2DModel.from_pretrained(
|
| 654 |
+
AIO_REPO_ID,
|
| 655 |
+
subfolder=f"{version}/transformer",
|
| 656 |
+
torch_dtype=dtype,
|
| 657 |
+
local_files_only=True,
|
| 658 |
+
).to(device)
|
| 659 |
+
|
|
|
|
|
|
|
|
|
|
| 660 |
try:
|
| 661 |
+
pipe.add_module("transformer", new_t)
|
| 662 |
except Exception:
|
| 663 |
pipe.transformer = new_t
|
| 664 |
|
| 665 |
_apply_fa3_if_possible()
|
| 666 |
|
| 667 |
+
CURRENT_AIO_VERSION = version
|
| 668 |
_hard_cuda_cleanup()
|
| 669 |
|
| 670 |
|
|
|
|
| 674 |
|
| 675 |
|
| 676 |
def on_lora_change_ui(selected_lora, current_prompt):
|
| 677 |
+
# Preset prompt (fill only if empty)
|
| 678 |
if selected_lora != NONE_LORA:
|
| 679 |
preset = LORA_PRESET_PROMPTS.get(selected_lora, "")
|
| 680 |
if preset and (current_prompt is None or str(current_prompt).strip() == ""):
|
|
|
|
| 684 |
else:
|
| 685 |
prompt_update = gr.update(value=current_prompt)
|
| 686 |
|
| 687 |
+
# Image2 visibility/label
|
| 688 |
if lora_requires_two_images(selected_lora):
|
| 689 |
img2_update = gr.update(visible=True, label=image2_label_for_lora(selected_lora))
|
| 690 |
else:
|
|
|
|
| 694 |
|
| 695 |
|
| 696 |
def refresh_aio_versions_ui(current_value: str):
|
| 697 |
+
global AVAILABLE_AIO_VERSIONS, DEFAULT_AIO_VERSION
|
| 698 |
+
AVAILABLE_AIO_VERSIONS = discover_aio_versions(AIO_REPO_ID)
|
| 699 |
+
|
| 700 |
+
if DEFAULT_AIO_VERSION not in AVAILABLE_AIO_VERSIONS and AVAILABLE_AIO_VERSIONS:
|
| 701 |
+
DEFAULT_AIO_VERSION = AVAILABLE_AIO_VERSIONS[0]
|
| 702 |
+
|
| 703 |
+
if current_value in AVAILABLE_AIO_VERSIONS:
|
| 704 |
new_value = current_value
|
| 705 |
else:
|
| 706 |
+
new_value = DEFAULT_AIO_VERSION
|
| 707 |
|
| 708 |
+
status = f"Found {len(AVAILABLE_AIO_VERSIONS)} version(s): {', '.join(AVAILABLE_AIO_VERSIONS)}"
|
| 709 |
+
return gr.update(choices=AVAILABLE_AIO_VERSIONS, value=new_value), gr.update(value=status)
|
| 710 |
|
| 711 |
|
| 712 |
# ============================================================
|
|
|
|
| 728 |
steps,
|
| 729 |
progress=gr.Progress(track_tqdm=True),
|
| 730 |
):
|
| 731 |
+
try:
|
| 732 |
+
_hard_cuda_cleanup()
|
| 733 |
|
| 734 |
+
if input_image_1 is None:
|
| 735 |
+
raise gr.Error("Please upload Image 1.")
|
| 736 |
|
| 737 |
+
# Switch AIO version quickly (local cache only). No downloads here.
|
| 738 |
+
if aio_version and aio_version != CURRENT_AIO_VERSION:
|
| 739 |
+
_switch_aio_version_local_only(aio_version)
|
| 740 |
|
| 741 |
+
# Handle "None"
|
| 742 |
+
if lora_adapter == NONE_LORA:
|
| 743 |
+
try:
|
| 744 |
+
pipe.set_adapters([], adapter_weights=[])
|
| 745 |
+
except Exception:
|
| 746 |
+
if LOADED_ADAPTERS:
|
| 747 |
+
pipe.set_adapters(list(LOADED_ADAPTERS), adapter_weights=[0.0] * len(LOADED_ADAPTERS))
|
| 748 |
+
else:
|
| 749 |
+
adapter_names, adapter_weights = _ensure_loaded_and_get_active_adapters(lora_adapter)
|
| 750 |
+
pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
|
| 751 |
|
| 752 |
+
if randomize_seed:
|
| 753 |
+
seed = random.randint(0, MAX_SEED)
|
| 754 |
|
| 755 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 756 |
+
negative_prompt = (
|
| 757 |
+
"worst quality, low quality, bad anatomy, bad hands, text, error, missing fingers, "
|
| 758 |
+
"extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry"
|
| 759 |
+
)
|
| 760 |
|
| 761 |
+
img1 = input_image_1.convert("RGB")
|
| 762 |
+
img2 = input_image_2.convert("RGB") if input_image_2 is not None else None
|
| 763 |
|
| 764 |
+
# Normalize extra images (Gallery) to PIL RGB
|
| 765 |
+
extra_imgs: list[Image.Image] = []
|
| 766 |
+
if input_images_extra:
|
| 767 |
+
for item in input_images_extra:
|
| 768 |
+
pil = _to_pil_rgb(item)
|
| 769 |
+
if pil is not None:
|
| 770 |
+
extra_imgs.append(pil)
|
| 771 |
|
| 772 |
+
# Enforce existing 2-image LoRA behavior
|
| 773 |
+
if lora_requires_two_images(lora_adapter) and img2 is None:
|
| 774 |
+
raise gr.Error("This LoRA needs two images. Please upload Image 2 as well.")
|
| 775 |
|
| 776 |
+
# Label images as image_1, image_2, image_3...
|
| 777 |
+
labeled = build_labeled_images(img1, img2, extra_imgs)
|
| 778 |
|
| 779 |
+
pipe_images = list(labeled.values())
|
| 780 |
+
if len(pipe_images) == 1:
|
| 781 |
+
pipe_images = pipe_images[0]
|
| 782 |
|
| 783 |
+
# Resolution derived from Image 1
|
| 784 |
+
target_long_edge = get_target_long_edge_for_lora(lora_adapter)
|
| 785 |
+
width, height = compute_dimensions(img1, target_long_edge)
|
| 786 |
|
|
|
|
| 787 |
result = pipe(
|
| 788 |
image=pipe_images,
|
| 789 |
prompt=prompt,
|
|
|
|
| 794 |
generator=generator,
|
| 795 |
true_cfg_scale=guidance_scale,
|
| 796 |
).images[0]
|
| 797 |
+
|
| 798 |
return result, seed
|
| 799 |
+
except Exception as e:
|
| 800 |
+
print("❌ Infer failed:\n", traceback.format_exc())
|
| 801 |
+
raise
|
| 802 |
finally:
|
| 803 |
_hard_cuda_cleanup()
|
| 804 |
|
|
|
|
| 810 |
input_pil = input_image.convert("RGB")
|
| 811 |
guidance_scale = 1.0
|
| 812 |
steps = 4
|
| 813 |
+
# Examples use current loaded AIO version, no swapping.
|
| 814 |
result, seed = infer(CURRENT_AIO_VERSION, input_pil, None, None, prompt, lora_adapter, 0, True, guidance_scale, steps)
|
| 815 |
return result, seed
|
| 816 |
|
|
|
|
| 839 |
with gr.Row():
|
| 840 |
aio_version = gr.Dropdown(
|
| 841 |
label="Phr00t Rapid AIO Version",
|
| 842 |
+
choices=AVAILABLE_AIO_VERSIONS,
|
| 843 |
+
value=DEFAULT_AIO_VERSION,
|
| 844 |
+
interactive=True,
|
| 845 |
)
|
| 846 |
refresh_versions = gr.Button("Refresh", variant="secondary")
|
| 847 |
+
|
| 848 |
+
aio_status = gr.Markdown(
|
| 849 |
+
f"Found {len(AVAILABLE_AIO_VERSIONS)} version(s): {', '.join(AVAILABLE_AIO_VERSIONS)}"
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
# When user changes version: CPU-cache it (fast if already cached)
|
| 853 |
+
# Also keep Run enabled (interactive=True)
|
| 854 |
+
run_button_placeholder = gr.Button("Edit Image", variant="primary", visible=False)
|
| 855 |
+
aio_version.change(
|
| 856 |
+
fn=ensure_aio_cached_ui,
|
| 857 |
+
inputs=[aio_version],
|
| 858 |
+
outputs=[aio_status, run_button_placeholder],
|
| 859 |
+
)
|
| 860 |
|
| 861 |
refresh_versions.click(
|
| 862 |
fn=refresh_aio_versions_ui,
|
|
|
|
| 944 |
label="Examples",
|
| 945 |
)
|
| 946 |
|
| 947 |
+
# Run pipeline:
|
| 948 |
+
# 1) CPU cache selected version (fast if already cached)
|
| 949 |
+
# 2) GPU infer (will switch using local_files_only=True if needed)
|
| 950 |
run_button.click(
|
| 951 |
+
fn=ensure_aio_cached_ui,
|
| 952 |
+
inputs=[aio_version],
|
| 953 |
+
outputs=[aio_status, run_button],
|
| 954 |
+
).then(
|
| 955 |
fn=infer,
|
| 956 |
inputs=[
|
| 957 |
aio_version,
|