diff --git a/generate_equirect.py b/generate_equirect.py index 30b8b4f..dd98a7e 100644 --- a/generate_equirect.py +++ b/generate_equirect.py @@ -39,6 +39,8 @@ from huggingface_hub import hf_hub_download MODEL_PATH = "ProGamerGov/sdxl-360-diffusion" BASE_SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" INPAINT_MODEL = "Lykon/dreamshaper-8-inpainting" +LOCAL_SDXL_360_MODEL = os.path.join("models", "sdxl360Diffusion_v10.safetensors") +SDXL_SINGLE_FILE_CONFIG = os.path.join(os.path.dirname(__file__), "configs", "sd_xl_base.yaml") TOPAZ_CLI = "/Applications/Topaz Photo AI.app/Contents/MacOS/Topaz Photo AI" REALESRGAN_MODEL = "RealESRGAN_x4plus.pth" REALESRGAN_SCALE = 4 # x4 model for full upscale @@ -46,6 +48,51 @@ REALESRGAN_SCALE = 4 # x4 model for full upscale SDXL_VAE = "stabilityai/sdxl-vae" +def downloads_allowed() -> bool: + return os.environ.get("SKYMAP_ALLOW_DOWNLOADS", "").strip().lower() in {"1", "true", "yes", "on"} + + +def local_files_only() -> bool: + return not downloads_allowed() + + +def resolve_existing_path(path: str, work_dir: str) -> str: + if os.path.isabs(path) and os.path.exists(path): + return path + + candidates = [ + os.path.abspath(path), + os.path.abspath(os.path.join(work_dir, path)), + ] + for candidate in candidates: + if os.path.exists(candidate): + return candidate + + return path + + +def resolve_generation_model_path(model_path: str, work_dir: str) -> str: + aliases = { + "proximasan/sdxl-360-diffusion": "ProGamerGov/sdxl-360-diffusion", + "proximasan": "ProGamerGov/sdxl-360-diffusion", + } + model_path = aliases.get(model_path, model_path) + + local_model = resolve_existing_path(LOCAL_SDXL_360_MODEL, work_dir) + if ( + model_path in {"", MODEL_PATH, "ProGamerGov/sdxl-360-diffusion"} + and os.path.isfile(local_model) + ): + print(f"→ Using local SDXL 360 checkpoint: {local_model}", flush=True) + return local_model + + return resolve_existing_path(model_path, work_dir) + + +def diffusers_load_kwargs() -> dict[str, bool]: + return {"local_files_only": local_files_only()} + + def sanitize_name(prompt: str) -> str: base = prompt.strip().lower() base = re.sub(r"\s+", "_", base)