Improve model loading robustness, add offline mode handling, and support local single-file SDXL checkpoints

This commit is contained in:
2026-05-07 15:12:59 +02:00
parent 52fe35e4f2
commit 25dc3cf952

View File

@@ -400,12 +400,21 @@ def postprocess_image(
mask = create_mask(width, height, mask_w)
print("→ Loading seam inpaint model…")
try:
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
INPAINT_MODEL,
torch_dtype=torch.float32,
safety_checker=None,
requires_safety_checker=False,
**diffusers_load_kwargs(),
).to(device)
except Exception as e: # noqa: BLE001
mode = "offline/local-cache" if local_files_only() else "online"
raise RuntimeError(
f"Failed to load seam inpaint model '{INPAINT_MODEL}' in {mode} mode. "
"If downloads are disabled, make sure the model is already present in the "
"Hugging Face cache or set SKYMAP_ALLOW_DOWNLOADS=1 before launching."
) from e
configure_pipeline_memory(inpaint_pipe, vae_tiling=False)
print("→ Inpainting seam for seamless tiling…")
@@ -525,18 +534,20 @@ def generate(
height: int = 512,
seam_inpaint: bool = False,
) -> str:
# Normalize common aliases that 404 on HF
aliases = {
"proximasan/sdxl-360-diffusion": "ProGamerGov/sdxl-360-diffusion",
"proximasan": "ProGamerGov/sdxl-360-diffusion",
}
model_path = aliases.get(model_path, model_path)
model_path = resolve_generation_model_path(model_path, work_dir)
device = select_device()
is_sdxl = "sdxl" in model_path.lower()
local_single_file = os.path.isfile(model_path)
is_sdxl = "sdxl" in model_path.lower() or local_single_file
enable_upscale = bool(upscale and upscale != "none")
os.makedirs(work_dir, exist_ok=True)
if local_files_only():
print(
"→ Offline mode enabled: using local files/cache only "
"(set SKYMAP_ALLOW_DOWNLOADS=1 to allow downloads).",
flush=True,
)
with tempfile.TemporaryDirectory(dir=work_dir) as tempdir:
print(f"→ Using tempdir: {tempdir}")
@@ -544,9 +555,14 @@ def generate(
gen_pipe = None
load_errors: list[str] = []
vae = None
if is_sdxl:
if is_sdxl and not local_single_file:
try:
vae = AutoencoderKL.from_pretrained(vae_model, subfolder="vae", torch_dtype=torch.float32)
vae = AutoencoderKL.from_pretrained(
vae_model,
subfolder="vae",
torch_dtype=torch.float32,
**diffusers_load_kwargs(),
)
except Exception as e: # noqa: BLE001
load_errors.append(f"vae: {e}")
@@ -556,18 +572,35 @@ def generate(
pipe_kwargs = {"torch_dtype": torch.float32}
if vae is not None:
pipe_kwargs["vae"] = vae
pipe_kwargs.update(diffusers_load_kwargs())
# Some SDXL repos only ship a UNet (e.g., sdxl-360); in that case load SDXL base
# and swap the UNet to keep the rest of the components consistent.
unet_only = False
if model_path and model_path.endswith("sdxl-360-diffusion"):
if not local_single_file and model_path and model_path.endswith("sdxl-360-diffusion"):
try:
hf_hub_download(model_path, "unet/config.json")
hf_hub_download(
model_path,
"unet/config.json",
local_files_only=local_files_only(),
)
unet_only = True
except Exception:
unet_only = False
if unet_only:
if local_single_file:
original_config_file = SDXL_SINGLE_FILE_CONFIG
if not os.path.isfile(original_config_file):
raise RuntimeError(
f"Local SDXL config not found at {original_config_file}. "
"This is required to load single-file checkpoints without fetching from GitHub."
)
gen_pipe = StableDiffusionXLPipeline.from_single_file(
model_path,
original_config_file=original_config_file,
**pipe_kwargs,
).to(device)
elif unet_only:
base_pipe = StableDiffusionXLPipeline.from_pretrained(
base_model,
**pipe_kwargs
@@ -575,7 +608,8 @@ def generate(
unet = UNet2DConditionModel.from_pretrained(
model_path,
subfolder="unet",
torch_dtype=torch.float32
torch_dtype=torch.float32,
**diffusers_load_kwargs(),
).to(device)
base_pipe.unet = unet
gen_pipe = base_pipe
@@ -588,13 +622,14 @@ def generate(
else:
gen_pipe = StableDiffusionPipeline.from_pretrained(
model_path,
torch_dtype=torch.float32
torch_dtype=torch.float32,
**diffusers_load_kwargs(),
).to(device)
except Exception as e: # noqa: BLE001
load_errors.append(f"from_pretrained: {e}")
# Fallback: single-file SDXL checkpoint from the repo
if gen_pipe is None:
if gen_pipe is None and not local_single_file:
ckpt_candidates = [
"sdxl_360_diffusion.safetensors",
"sdxl_360_diffusion_unet.safetensors",
@@ -603,10 +638,17 @@ def generate(
last_err = None
for fname in ckpt_candidates:
try:
ckpt = hf_hub_download(model_path, fname)
ckpt = hf_hub_download(
model_path,
fname,
local_files_only=local_files_only(),
)
pipe_kwargs = {"torch_dtype": torch.float32}
if vae is not None:
pipe_kwargs["vae"] = vae
pipe_kwargs.update(diffusers_load_kwargs())
if is_sdxl and os.path.isfile(SDXL_SINGLE_FILE_CONFIG):
pipe_kwargs["original_config_file"] = SDXL_SINGLE_FILE_CONFIG
gen_pipe = StableDiffusionXLPipeline.from_single_file(
ckpt,
**pipe_kwargs
@@ -627,7 +669,8 @@ def generate(
vae = vae or AutoencoderKL.from_pretrained(
SDXL_VAE,
subfolder="vae",
torch_dtype=torch.float32
torch_dtype=torch.float32,
**diffusers_load_kwargs(),
)
gen_pipe.vae = vae.to(device)
gen_pipe.to(device)