Improve model loading robustness, add offline mode handling, and support local single-file SDXL checkpoints
This commit is contained in:
@@ -400,12 +400,21 @@ def postprocess_image(
|
||||
mask = create_mask(width, height, mask_w)
|
||||
|
||||
print("→ Loading seam inpaint model…")
|
||||
try:
|
||||
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
INPAINT_MODEL,
|
||||
torch_dtype=torch.float32,
|
||||
safety_checker=None,
|
||||
requires_safety_checker=False,
|
||||
**diffusers_load_kwargs(),
|
||||
).to(device)
|
||||
except Exception as e: # noqa: BLE001
|
||||
mode = "offline/local-cache" if local_files_only() else "online"
|
||||
raise RuntimeError(
|
||||
f"Failed to load seam inpaint model '{INPAINT_MODEL}' in {mode} mode. "
|
||||
"If downloads are disabled, make sure the model is already present in the "
|
||||
"Hugging Face cache or set SKYMAP_ALLOW_DOWNLOADS=1 before launching."
|
||||
) from e
|
||||
configure_pipeline_memory(inpaint_pipe, vae_tiling=False)
|
||||
|
||||
print("→ Inpainting seam for seamless tiling…")
|
||||
@@ -525,18 +534,20 @@ def generate(
|
||||
height: int = 512,
|
||||
seam_inpaint: bool = False,
|
||||
) -> str:
|
||||
# Normalize common aliases that 404 on HF
|
||||
aliases = {
|
||||
"proximasan/sdxl-360-diffusion": "ProGamerGov/sdxl-360-diffusion",
|
||||
"proximasan": "ProGamerGov/sdxl-360-diffusion",
|
||||
}
|
||||
model_path = aliases.get(model_path, model_path)
|
||||
model_path = resolve_generation_model_path(model_path, work_dir)
|
||||
|
||||
device = select_device()
|
||||
is_sdxl = "sdxl" in model_path.lower()
|
||||
local_single_file = os.path.isfile(model_path)
|
||||
is_sdxl = "sdxl" in model_path.lower() or local_single_file
|
||||
enable_upscale = bool(upscale and upscale != "none")
|
||||
|
||||
os.makedirs(work_dir, exist_ok=True)
|
||||
if local_files_only():
|
||||
print(
|
||||
"→ Offline mode enabled: using local files/cache only "
|
||||
"(set SKYMAP_ALLOW_DOWNLOADS=1 to allow downloads).",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory(dir=work_dir) as tempdir:
|
||||
print(f"→ Using tempdir: {tempdir}")
|
||||
@@ -544,9 +555,14 @@ def generate(
|
||||
gen_pipe = None
|
||||
load_errors: list[str] = []
|
||||
vae = None
|
||||
if is_sdxl:
|
||||
if is_sdxl and not local_single_file:
|
||||
try:
|
||||
vae = AutoencoderKL.from_pretrained(vae_model, subfolder="vae", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_model,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch.float32,
|
||||
**diffusers_load_kwargs(),
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
load_errors.append(f"vae: {e}")
|
||||
|
||||
@@ -556,18 +572,35 @@ def generate(
|
||||
pipe_kwargs = {"torch_dtype": torch.float32}
|
||||
if vae is not None:
|
||||
pipe_kwargs["vae"] = vae
|
||||
pipe_kwargs.update(diffusers_load_kwargs())
|
||||
|
||||
# Some SDXL repos only ship a UNet (e.g., sdxl-360); in that case load SDXL base
|
||||
# and swap the UNet to keep the rest of the components consistent.
|
||||
unet_only = False
|
||||
if model_path and model_path.endswith("sdxl-360-diffusion"):
|
||||
if not local_single_file and model_path and model_path.endswith("sdxl-360-diffusion"):
|
||||
try:
|
||||
hf_hub_download(model_path, "unet/config.json")
|
||||
hf_hub_download(
|
||||
model_path,
|
||||
"unet/config.json",
|
||||
local_files_only=local_files_only(),
|
||||
)
|
||||
unet_only = True
|
||||
except Exception:
|
||||
unet_only = False
|
||||
|
||||
if unet_only:
|
||||
if local_single_file:
|
||||
original_config_file = SDXL_SINGLE_FILE_CONFIG
|
||||
if not os.path.isfile(original_config_file):
|
||||
raise RuntimeError(
|
||||
f"Local SDXL config not found at {original_config_file}. "
|
||||
"This is required to load single-file checkpoints without fetching from GitHub."
|
||||
)
|
||||
gen_pipe = StableDiffusionXLPipeline.from_single_file(
|
||||
model_path,
|
||||
original_config_file=original_config_file,
|
||||
**pipe_kwargs,
|
||||
).to(device)
|
||||
elif unet_only:
|
||||
base_pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
base_model,
|
||||
**pipe_kwargs
|
||||
@@ -575,7 +608,8 @@ def generate(
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
model_path,
|
||||
subfolder="unet",
|
||||
torch_dtype=torch.float32
|
||||
torch_dtype=torch.float32,
|
||||
**diffusers_load_kwargs(),
|
||||
).to(device)
|
||||
base_pipe.unet = unet
|
||||
gen_pipe = base_pipe
|
||||
@@ -588,13 +622,14 @@ def generate(
|
||||
else:
|
||||
gen_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.float32
|
||||
torch_dtype=torch.float32,
|
||||
**diffusers_load_kwargs(),
|
||||
).to(device)
|
||||
except Exception as e: # noqa: BLE001
|
||||
load_errors.append(f"from_pretrained: {e}")
|
||||
|
||||
# Fallback: single-file SDXL checkpoint from the repo
|
||||
if gen_pipe is None:
|
||||
if gen_pipe is None and not local_single_file:
|
||||
ckpt_candidates = [
|
||||
"sdxl_360_diffusion.safetensors",
|
||||
"sdxl_360_diffusion_unet.safetensors",
|
||||
@@ -603,10 +638,17 @@ def generate(
|
||||
last_err = None
|
||||
for fname in ckpt_candidates:
|
||||
try:
|
||||
ckpt = hf_hub_download(model_path, fname)
|
||||
ckpt = hf_hub_download(
|
||||
model_path,
|
||||
fname,
|
||||
local_files_only=local_files_only(),
|
||||
)
|
||||
pipe_kwargs = {"torch_dtype": torch.float32}
|
||||
if vae is not None:
|
||||
pipe_kwargs["vae"] = vae
|
||||
pipe_kwargs.update(diffusers_load_kwargs())
|
||||
if is_sdxl and os.path.isfile(SDXL_SINGLE_FILE_CONFIG):
|
||||
pipe_kwargs["original_config_file"] = SDXL_SINGLE_FILE_CONFIG
|
||||
gen_pipe = StableDiffusionXLPipeline.from_single_file(
|
||||
ckpt,
|
||||
**pipe_kwargs
|
||||
@@ -627,7 +669,8 @@ def generate(
|
||||
vae = vae or AutoencoderKL.from_pretrained(
|
||||
SDXL_VAE,
|
||||
subfolder="vae",
|
||||
torch_dtype=torch.float32
|
||||
torch_dtype=torch.float32,
|
||||
**diffusers_load_kwargs(),
|
||||
)
|
||||
gen_pipe.vae = vae.to(device)
|
||||
gen_pipe.to(device)
|
||||
|
||||
Reference in New Issue
Block a user