Improve model loading robustness, add offline mode handling, and support local single-file SDXL checkpoints
This commit is contained in:
@@ -400,12 +400,21 @@ def postprocess_image(
|
|||||||
mask = create_mask(width, height, mask_w)
|
mask = create_mask(width, height, mask_w)
|
||||||
|
|
||||||
print("→ Loading seam inpaint model…")
|
print("→ Loading seam inpaint model…")
|
||||||
|
try:
|
||||||
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
INPAINT_MODEL,
|
INPAINT_MODEL,
|
||||||
torch_dtype=torch.float32,
|
torch_dtype=torch.float32,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
|
**diffusers_load_kwargs(),
|
||||||
).to(device)
|
).to(device)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
mode = "offline/local-cache" if local_files_only() else "online"
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to load seam inpaint model '{INPAINT_MODEL}' in {mode} mode. "
|
||||||
|
"If downloads are disabled, make sure the model is already present in the "
|
||||||
|
"Hugging Face cache or set SKYMAP_ALLOW_DOWNLOADS=1 before launching."
|
||||||
|
) from e
|
||||||
configure_pipeline_memory(inpaint_pipe, vae_tiling=False)
|
configure_pipeline_memory(inpaint_pipe, vae_tiling=False)
|
||||||
|
|
||||||
print("→ Inpainting seam for seamless tiling…")
|
print("→ Inpainting seam for seamless tiling…")
|
||||||
@@ -525,18 +534,20 @@ def generate(
|
|||||||
height: int = 512,
|
height: int = 512,
|
||||||
seam_inpaint: bool = False,
|
seam_inpaint: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Normalize common aliases that 404 on HF
|
model_path = resolve_generation_model_path(model_path, work_dir)
|
||||||
aliases = {
|
|
||||||
"proximasan/sdxl-360-diffusion": "ProGamerGov/sdxl-360-diffusion",
|
|
||||||
"proximasan": "ProGamerGov/sdxl-360-diffusion",
|
|
||||||
}
|
|
||||||
model_path = aliases.get(model_path, model_path)
|
|
||||||
|
|
||||||
device = select_device()
|
device = select_device()
|
||||||
is_sdxl = "sdxl" in model_path.lower()
|
local_single_file = os.path.isfile(model_path)
|
||||||
|
is_sdxl = "sdxl" in model_path.lower() or local_single_file
|
||||||
enable_upscale = bool(upscale and upscale != "none")
|
enable_upscale = bool(upscale and upscale != "none")
|
||||||
|
|
||||||
os.makedirs(work_dir, exist_ok=True)
|
os.makedirs(work_dir, exist_ok=True)
|
||||||
|
if local_files_only():
|
||||||
|
print(
|
||||||
|
"→ Offline mode enabled: using local files/cache only "
|
||||||
|
"(set SKYMAP_ALLOW_DOWNLOADS=1 to allow downloads).",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory(dir=work_dir) as tempdir:
|
with tempfile.TemporaryDirectory(dir=work_dir) as tempdir:
|
||||||
print(f"→ Using tempdir: {tempdir}")
|
print(f"→ Using tempdir: {tempdir}")
|
||||||
@@ -544,9 +555,14 @@ def generate(
|
|||||||
gen_pipe = None
|
gen_pipe = None
|
||||||
load_errors: list[str] = []
|
load_errors: list[str] = []
|
||||||
vae = None
|
vae = None
|
||||||
if is_sdxl:
|
if is_sdxl and not local_single_file:
|
||||||
try:
|
try:
|
||||||
vae = AutoencoderKL.from_pretrained(vae_model, subfolder="vae", torch_dtype=torch.float32)
|
vae = AutoencoderKL.from_pretrained(
|
||||||
|
vae_model,
|
||||||
|
subfolder="vae",
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
**diffusers_load_kwargs(),
|
||||||
|
)
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e: # noqa: BLE001
|
||||||
load_errors.append(f"vae: {e}")
|
load_errors.append(f"vae: {e}")
|
||||||
|
|
||||||
@@ -556,18 +572,35 @@ def generate(
|
|||||||
pipe_kwargs = {"torch_dtype": torch.float32}
|
pipe_kwargs = {"torch_dtype": torch.float32}
|
||||||
if vae is not None:
|
if vae is not None:
|
||||||
pipe_kwargs["vae"] = vae
|
pipe_kwargs["vae"] = vae
|
||||||
|
pipe_kwargs.update(diffusers_load_kwargs())
|
||||||
|
|
||||||
# Some SDXL repos only ship a UNet (e.g., sdxl-360); in that case load SDXL base
|
# Some SDXL repos only ship a UNet (e.g., sdxl-360); in that case load SDXL base
|
||||||
# and swap the UNet to keep the rest of the components consistent.
|
# and swap the UNet to keep the rest of the components consistent.
|
||||||
unet_only = False
|
unet_only = False
|
||||||
if model_path and model_path.endswith("sdxl-360-diffusion"):
|
if not local_single_file and model_path and model_path.endswith("sdxl-360-diffusion"):
|
||||||
try:
|
try:
|
||||||
hf_hub_download(model_path, "unet/config.json")
|
hf_hub_download(
|
||||||
|
model_path,
|
||||||
|
"unet/config.json",
|
||||||
|
local_files_only=local_files_only(),
|
||||||
|
)
|
||||||
unet_only = True
|
unet_only = True
|
||||||
except Exception:
|
except Exception:
|
||||||
unet_only = False
|
unet_only = False
|
||||||
|
|
||||||
if unet_only:
|
if local_single_file:
|
||||||
|
original_config_file = SDXL_SINGLE_FILE_CONFIG
|
||||||
|
if not os.path.isfile(original_config_file):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Local SDXL config not found at {original_config_file}. "
|
||||||
|
"This is required to load single-file checkpoints without fetching from GitHub."
|
||||||
|
)
|
||||||
|
gen_pipe = StableDiffusionXLPipeline.from_single_file(
|
||||||
|
model_path,
|
||||||
|
original_config_file=original_config_file,
|
||||||
|
**pipe_kwargs,
|
||||||
|
).to(device)
|
||||||
|
elif unet_only:
|
||||||
base_pipe = StableDiffusionXLPipeline.from_pretrained(
|
base_pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
**pipe_kwargs
|
**pipe_kwargs
|
||||||
@@ -575,7 +608,8 @@ def generate(
|
|||||||
unet = UNet2DConditionModel.from_pretrained(
|
unet = UNet2DConditionModel.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
subfolder="unet",
|
subfolder="unet",
|
||||||
torch_dtype=torch.float32
|
torch_dtype=torch.float32,
|
||||||
|
**diffusers_load_kwargs(),
|
||||||
).to(device)
|
).to(device)
|
||||||
base_pipe.unet = unet
|
base_pipe.unet = unet
|
||||||
gen_pipe = base_pipe
|
gen_pipe = base_pipe
|
||||||
@@ -588,13 +622,14 @@ def generate(
|
|||||||
else:
|
else:
|
||||||
gen_pipe = StableDiffusionPipeline.from_pretrained(
|
gen_pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype=torch.float32
|
torch_dtype=torch.float32,
|
||||||
|
**diffusers_load_kwargs(),
|
||||||
).to(device)
|
).to(device)
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e: # noqa: BLE001
|
||||||
load_errors.append(f"from_pretrained: {e}")
|
load_errors.append(f"from_pretrained: {e}")
|
||||||
|
|
||||||
# Fallback: single-file SDXL checkpoint from the repo
|
# Fallback: single-file SDXL checkpoint from the repo
|
||||||
if gen_pipe is None:
|
if gen_pipe is None and not local_single_file:
|
||||||
ckpt_candidates = [
|
ckpt_candidates = [
|
||||||
"sdxl_360_diffusion.safetensors",
|
"sdxl_360_diffusion.safetensors",
|
||||||
"sdxl_360_diffusion_unet.safetensors",
|
"sdxl_360_diffusion_unet.safetensors",
|
||||||
@@ -603,10 +638,17 @@ def generate(
|
|||||||
last_err = None
|
last_err = None
|
||||||
for fname in ckpt_candidates:
|
for fname in ckpt_candidates:
|
||||||
try:
|
try:
|
||||||
ckpt = hf_hub_download(model_path, fname)
|
ckpt = hf_hub_download(
|
||||||
|
model_path,
|
||||||
|
fname,
|
||||||
|
local_files_only=local_files_only(),
|
||||||
|
)
|
||||||
pipe_kwargs = {"torch_dtype": torch.float32}
|
pipe_kwargs = {"torch_dtype": torch.float32}
|
||||||
if vae is not None:
|
if vae is not None:
|
||||||
pipe_kwargs["vae"] = vae
|
pipe_kwargs["vae"] = vae
|
||||||
|
pipe_kwargs.update(diffusers_load_kwargs())
|
||||||
|
if is_sdxl and os.path.isfile(SDXL_SINGLE_FILE_CONFIG):
|
||||||
|
pipe_kwargs["original_config_file"] = SDXL_SINGLE_FILE_CONFIG
|
||||||
gen_pipe = StableDiffusionXLPipeline.from_single_file(
|
gen_pipe = StableDiffusionXLPipeline.from_single_file(
|
||||||
ckpt,
|
ckpt,
|
||||||
**pipe_kwargs
|
**pipe_kwargs
|
||||||
@@ -627,7 +669,8 @@ def generate(
|
|||||||
vae = vae or AutoencoderKL.from_pretrained(
|
vae = vae or AutoencoderKL.from_pretrained(
|
||||||
SDXL_VAE,
|
SDXL_VAE,
|
||||||
subfolder="vae",
|
subfolder="vae",
|
||||||
torch_dtype=torch.float32
|
torch_dtype=torch.float32,
|
||||||
|
**diffusers_load_kwargs(),
|
||||||
)
|
)
|
||||||
gen_pipe.vae = vae.to(device)
|
gen_pipe.vae = vae.to(device)
|
||||||
gen_pipe.to(device)
|
gen_pipe.to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user