# NOTE: the following lines are artifacts of the web viewer this file was
# copied from (file listing header), kept as comments so the script parses:
# Files
# skymap-gen/generate_equirect.py
# 888 lines
# 32 KiB
# Python
# Raw Normal View History

#!/usr/bin/env python3
"""
Generate an equirectangular HDRI image using Diffusers.
Optional upscaling:
--upscale topaz      -> legacy Topaz Photo AI CLI (if installed)
--upscale realesrgan -> open-source Real-ESRGAN in-Python upscaler (pip install realesrgan==0.3.0 basicsr opencv-python)
Default flow: prompt in, equirectangular PNG out. Add --seam-inpaint to patch the horizontal wrap seam.
"""
import argparse
import gc
import json
import math
import os
import re
import shutil
import sys
import tempfile
import torch
from PIL import Image, ImageDraw
from PIL.PngImagePlugin import PngInfo
import numpy as np
from diffusers import (
StableDiffusionPipeline,
StableDiffusionXLPipeline,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDIMScheduler,
StableDiffusionInpaintPipeline,
AutoencoderKL,
UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
# Default panorama model. You can override via --model-path.
# Uses the SDXL 360 diffusion checkpoint from ProGamerGov (single-file safetensors).
MODEL_PATH = "ProGamerGov/sdxl-360-diffusion"
# SDXL base pipeline used when a repo only ships a UNet (see generate()).
BASE_SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
# Inpainting model used by the optional --seam-inpaint pass.
INPAINT_MODEL = "Lykon/dreamshaper-8-inpainting"
# Preferred local single-file checkpoint; used when present (resolve_generation_model_path).
LOCAL_SDXL_360_MODEL = os.path.join("models", "sdxl360Diffusion_v10.safetensors")
# Bundled SDXL config required to load single-file checkpoints without network access.
SDXL_SINGLE_FILE_CONFIG = os.path.join(os.path.dirname(__file__), "configs", "sd_xl_base.yaml")
# macOS install location of the Topaz Photo AI CLI (legacy --upscale topaz path).
TOPAZ_CLI = "/Applications/Topaz Photo AI.app/Contents/MacOS/Topaz Photo AI"
# Real-ESRGAN weights file, expected next to this script (see run_realesrgan).
REALESRGAN_MODEL = "RealESRGAN_x4plus.pth"
REALESRGAN_SCALE = 4 # x4 model for full upscale
# Recommended VAE for SDXL checkpoints
SDXL_VAE = "stabilityai/sdxl-vae"
def downloads_allowed() -> bool:
    """Return True when SKYMAP_ALLOW_DOWNLOADS opts into remote model downloads."""
    flag = os.environ.get("SKYMAP_ALLOW_DOWNLOADS", "")
    return flag.strip().lower() in {"1", "true", "yes", "on"}
def local_files_only() -> bool:
    """Inverse of downloads_allowed(): True means stay offline (local cache only)."""
    if downloads_allowed():
        return False
    return True
def resolve_existing_path(path: str, work_dir: str) -> str:
    """Resolve *path* to an existing location when possible.

    An existing absolute path is returned untouched.  Otherwise the path is
    tried relative to the current directory, then relative to *work_dir*;
    the first candidate that exists wins.  When nothing exists the original
    string is returned unchanged.
    """
    if os.path.isabs(path) and os.path.exists(path):
        return path
    for candidate in (
        os.path.abspath(path),
        os.path.abspath(os.path.join(work_dir, path)),
    ):
        if os.path.exists(candidate):
            return candidate
    return path
def resolve_generation_model_path(model_path: str, work_dir: str) -> str:
    """Resolve the generation model path, preferring the local checkpoint.

    Known repo aliases are rewritten first.  When the caller asked for the
    default model (or left it empty) and the bundled single-file checkpoint
    exists on disk, that local file wins; otherwise the requested path is
    resolved via resolve_existing_path().
    """
    alias_map = {
        "proximasan/sdxl-360-diffusion": "ProGamerGov/sdxl-360-diffusion",
        "proximasan": "ProGamerGov/sdxl-360-diffusion",
    }
    requested = alias_map.get(model_path, model_path)
    wants_default = requested in {"", MODEL_PATH, "ProGamerGov/sdxl-360-diffusion"}
    local_model = resolve_existing_path(LOCAL_SDXL_360_MODEL, work_dir)
    if wants_default and os.path.isfile(local_model):
        print(f"→ Using local SDXL 360 checkpoint: {local_model}", flush=True)
        return local_model
    return resolve_existing_path(requested, work_dir)
def diffusers_load_kwargs() -> dict[str, bool]:
    """Common kwargs for Diffusers from_pretrained calls (offline toggle)."""
    offline = local_files_only()
    return {"local_files_only": offline}
def sanitize_name(prompt: str) -> str:
    """Turn a prompt into a filesystem-safe snake_case base name.

    Lowercases, converts whitespace runs to underscores, strips everything
    outside [a-z0-9_], and falls back to "env" when nothing remains.
    """
    name = re.sub(r"\s+", "_", prompt.strip().lower())
    name = re.sub(r"[^a-z0-9_]+", "", name)
    if name:
        return name
    return "env"
def next_filename(output_dir: str, base: str, width: int, height: int) -> str:
    """Return the first free '<base>-<i>-<width>x<height>.png' path in *output_dir*.

    Creates *output_dir* when missing and counts i upward from 1 until an
    unused filename is found.
    """
    os.makedirs(output_dir, exist_ok=True)
    index = 1
    candidate = os.path.join(output_dir, f"{base}-{index}-{width}x{height}.png")
    while os.path.exists(candidate):
        index += 1
        candidate = os.path.join(output_dir, f"{base}-{index}-{width}x{height}.png")
    return candidate
def save_png_with_prompt(img: Image.Image, out_path: str, prompt: str) -> None:
    """Write *img* as PNG with the prompt embedded as a text chunk.

    Carries over ICC profile, EXIF and DPI metadata from the source image
    when present so the output keeps its colour/profile information.
    """
    metadata = PngInfo()
    metadata.add_text("prompt", prompt)
    carried = {k: img.info[k] for k in ("icc_profile", "exif", "dpi") if k in img.info}
    img.save(out_path, pnginfo=metadata, **carried)
def shift_image(img: Image.Image, shift: int) -> Image.Image:
    """Roll the image horizontally left by *shift* pixels with wrap-around."""
    width, height = img.size
    rolled = Image.new("RGB", (width, height))
    rolled.paste(img.crop((shift, 0, width, height)), (0, 0))
    rolled.paste(img.crop((0, 0, shift, height)), (width - shift, 0))
    return rolled
def create_mask(width: int, height: int, mask_w: int) -> Image.Image:
    """Build an L-mode mask: a white vertical band of ~mask_w px, centred horizontally."""
    band_left = (width - mask_w) // 2
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).rectangle([band_left, 0, band_left + mask_w, height], fill=255)
    return mask
def unshift_image(img: Image.Image, shift: int) -> Image.Image:
    """Undo shift_image(): roll the image horizontally right by *shift* pixels."""
    width, height = img.size
    restored = Image.new("RGB", (width, height))
    restored.paste(img.crop((width - shift, 0, width, height)), (0, 0))
    restored.paste(img.crop((0, 0, width - shift, height)), (shift, 0))
    return restored
def select_device() -> str:
    """Pick the best available torch device: MPS first, then CUDA, then CPU."""
    for name, probe in (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    ):
        if probe():
            return name
    return "cpu"
def clear_torch_cache(device: str | None = None) -> None:
gc.collect()
if device == "cuda" and torch.cuda.is_available():
torch.cuda.empty_cache()
elif device == "mps" and torch.backends.mps.is_available() and hasattr(torch, "mps"):
try:
torch.mps.synchronize()
except Exception:
pass
empty_cache = getattr(torch.mps, "empty_cache", None)
if empty_cache is not None:
empty_cache()
def configure_pipeline_memory(pipe, *, vae_tiling: bool = True) -> None:
    """Enable the memory-saving switches a Diffusers pipeline exposes.

    Attention slicing is always enabled.  VAE tiling is enabled or disabled
    per the keyword flag (only when the pipeline has the matching method),
    and VAE slicing is enabled when available.
    """
    pipe.enable_attention_slicing()
    if hasattr(pipe, "enable_vae_tiling") and vae_tiling:
        pipe.enable_vae_tiling()
    elif hasattr(pipe, "disable_vae_tiling") and not vae_tiling:
        pipe.disable_vae_tiling()
    if hasattr(pipe, "enable_vae_slicing"):
        pipe.enable_vae_slicing()
def make_progress_cb(enable_upscale: bool, seam_inpaint: bool):
    """Build a callback that prints machine-readable 'PROGRESS {json}' lines.

    The returned callable takes (phase, current, total) and emits one JSON
    payload per call, also carrying the upscale/seam-inpaint flags so a
    supervising process can compute overall progress.
    """
    def progress_cb(phase: str, current: int, total: int):
        encoded = json.dumps(
            {
                "phase": phase,
                "current": current,
                "total": total,
                "upscale": enable_upscale,
                "seamInpaint": seam_inpaint,
            }
        )
        print(f"PROGRESS {encoded}", flush=True)
    return progress_cb
def decode_latents_to_image(vae: AutoencoderKL, latents: torch.Tensor, device: str) -> Image.Image:
    """Decode diffusion latents into a PIL image with a standalone VAE.

    Enables VAE tiling/slicing when supported to bound peak memory, moves the
    latents to the VAE's device/dtype, de-normalises them (SDXL-style
    latents_mean/latents_std when the config provides both, otherwise a plain
    division by scaling_factor), decodes, and converts the first batch item
    to an 8-bit RGB PIL image.  Caches are cleared before returning.
    """
    print("→ Decoding latent image with standalone VAE…")
    # Tiling/slicing keep decode memory bounded on large panoramas.
    if hasattr(vae, "enable_tiling"):
        vae.enable_tiling()
    if hasattr(vae, "enable_slicing"):
        vae.enable_slicing()
    vae.to(device)
    vae.eval()
    # Match the latents to whatever precision the VAE weights use.
    vae_dtype = next(vae.parameters()).dtype
    with torch.inference_mode():
        latents = latents.to(device=device, dtype=vae_dtype)
        has_latents_mean = hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None
        has_latents_std = hasattr(vae.config, "latents_std") and vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            # SDXL-style de-normalisation: x = z * std / scaling_factor + mean.
            latent_channels = len(vae.config.latents_mean)
            latents_mean = (
                torch.tensor(vae.config.latents_mean)
                .view(1, latent_channels, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(vae.config.latents_std)
                .view(1, latent_channels, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / vae.config.scaling_factor + latents_mean
        else:
            latents = latents / vae.config.scaling_factor
        decoded = vae.decode(latents, return_dict=False)[0]
        # VAE output is in [-1, 1]; map to [0, 1] before quantising.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        image = decoded[0].detach().cpu().permute(1, 2, 0).float().numpy()
        del decoded, latents
    clear_torch_cache(device)
    return Image.fromarray((image * 255).round().astype("uint8"))
def run_topaz(input_path: str, tempdir: str) -> str:
    """Upscale *input_path* with the Topaz Photo AI CLI, writing into *tempdir*.

    Returns the path of the newest PNG Topaz produced in *tempdir*.

    Raises:
        RuntimeError: when the CLI cannot be invoked, exits non-zero, or
            produces no PNG output.
    """
    import subprocess  # local import: only needed on the legacy Topaz path

    print("→ Upscaling with Topaz Photo AI CLI…")
    try:
        # An argument list avoids the shell entirely, so paths containing
        # spaces/quotes (like the app bundle path itself) can't break the
        # command the way the previous os.system() string could.
        result = subprocess.run(
            [TOPAZ_CLI, "--cli", input_path, "-o", tempdir],
            check=False,
        )
    except Exception as e:  # noqa: BLE001
        raise RuntimeError(f"Topaz invocation failed: {e}") from e
    if result.returncode != 0:
        raise RuntimeError("Topaz CLI returned non-zero exit code")
    # Topaz picks its own output filename; take the most recently written PNG.
    upscaled_files = sorted(
        (os.path.join(tempdir, f) for f in os.listdir(tempdir) if f.lower().endswith('.png')),
        key=os.path.getmtime,
        reverse=True,
    )
    if not upscaled_files:
        raise RuntimeError("Topaz produced no PNG output")
    print(f"→ Upscaled file: {upscaled_files[0]}")
    return upscaled_files[0]
def run_realesrgan(
    input_image: Image.Image,
    tempdir: str,
    scale: int = 4,
    model_path: str = REALESRGAN_MODEL,
    progress_cb=None
) -> str:
    """Upscale *input_image* with Real-ESRGAN and return the output PNG path.

    Args:
        input_image: source PIL image (converted to BGR for OpenCV/Real-ESRGAN).
        tempdir: directory where the upscaled PNG is written.
        scale: upscale factor of the model (x4 for RealESRGAN_x4plus).
        model_path: path to the .pth weights file.
        progress_cb: optional callable(phase, current, total) reporting
            per-tile progress under the "upscale" phase.

    Raises:
        RuntimeError: when the Real-ESRGAN dependencies or the weights file
            are missing.
    """
    try:
        # Compatibility shim for newer torchvision where functional_tensor moved/renamed
        try:
            import torchvision.transforms._functional_tensor as _ft  # type: ignore
            sys.modules.setdefault("torchvision.transforms.functional_tensor", _ft)
        except Exception:
            pass
        from realesrgan import RealESRGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet
        import cv2
        import types
    except Exception as e:  # noqa: BLE001
        raise RuntimeError(
            "Real-ESRGAN dependencies missing. Install with: pip install realesrgan==0.3.0 basicsr opencv-python torchvision"
        ) from e
    device = select_device()
    if not model_path or not os.path.exists(model_path):
        raise RuntimeError(
            f"Real-ESRGAN model not found at {model_path!r}. "
            "Place RealESRGAN_x4plus.pth next to this script or update REALESRGAN_MODEL."
        )
    img_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    print(f"→ Upscaling with Real-ESRGAN (x{scale}) on {device}")
    # RRDBNet layout matching the RealESRGAN_x4plus checkpoint.
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )
    upsampler = RealESRGANer(
        model_path=model_path,
        scale=scale,
        model=model,
        tile=64,  # aggressive tiling to keep memory & runtime manageable
        tile_pad=10,
        pre_pad=0,
        half=False,  # keep full precision for CPU/MPS
    )
    # Wrap tile processing to surface progress per tile.  Mirrors the upstream
    # RealESRGANer.tile_process loop with a progress callback added.
    def tile_process_with_progress(self):
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)
        total_tiles = max(1, tiles_x * tiles_y)
        for y in range(tiles_y):
            for x in range(tiles_x):
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)
                # Pad each tile so the model sees context past the tile edge.
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    # BUGFIX: the previous code fell through after a failed
                    # tile, so `output_tile` was undefined (NameError) or a
                    # stale tile from the prior iteration got pasted.  Skip
                    # this tile instead, leaving it zero-filled.
                    print('Error', error)
                    if progress_cb:
                        progress_cb("upscale", tile_idx, total_tiles)
                    continue
                if progress_cb:
                    progress_cb("upscale", tile_idx, total_tiles)
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale
                # Crop the padded context back out of the model output.
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]
    if progress_cb:
        upsampler.tile_process = types.MethodType(tile_process_with_progress, upsampler)
    if progress_cb:
        progress_cb("upscale", 0, 1)
    sr_img, _ = upsampler.enhance(img_bgr, outscale=scale)
    sr_img = Image.fromarray(cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB))
    if progress_cb:
        progress_cb("upscale", 1, 1)
    out_path = os.path.join(tempdir, f"realesrgan_x{scale}.png")
    sr_img.save(out_path)
    print(f"→ Real-ESRGAN output: {out_path}")
    return out_path
def postprocess_image(
    prompt: str,
    input_path: str,
    output_path: str,
    tempdir: str,
    upscale: str = "none",
    steps: int = 25,
    guidance: float = 4.5,
    width: int = 1024,
    height: int = 512,
    seam_inpaint: bool = False,
) -> str:
    """Post-process a generated panorama: optional seam inpaint, optional upscale.

    The seam pass shifts the image by half its width so the wrap seam sits in
    the centre, inpaints a vertical band there with INPAINT_MODEL, then
    shifts back.  The result (or the raw input when seam_inpaint is False)
    is optionally upscaled (topaz/realesrgan) and finally written to
    *output_path* with the prompt embedded as PNG metadata.  Returns
    *output_path*.  Raises RuntimeError when the inpaint model fails to load.
    """
    device = select_device()
    enable_upscale = bool(upscale and upscale != "none")
    progress_cb = make_progress_cb(enable_upscale, seam_inpaint)
    with Image.open(input_path) as input_img:
        image = input_img.convert("RGB")
    seamless_path = os.path.join(tempdir, os.path.basename(output_path))
    if seam_inpaint:
        # Put the wrap seam in the middle of the frame so it can be inpainted.
        shift_amt = width // 2
        mask_w = width // 8
        shifted = shift_image(image, shift_amt)
        mask = create_mask(width, height, mask_w)
        print("→ Loading seam inpaint model…")
        try:
            inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
                INPAINT_MODEL,
                torch_dtype=torch.float32,
                safety_checker=None,
                requires_safety_checker=False,
                **diffusers_load_kwargs(),
            ).to(device)
        except Exception as e:  # noqa: BLE001
            mode = "offline/local-cache" if local_files_only() else "online"
            raise RuntimeError(
                f"Failed to load seam inpaint model '{INPAINT_MODEL}' in {mode} mode. "
                "If downloads are disabled, make sure the model is already present in the "
                "Hugging Face cache or set SKYMAP_ALLOW_DOWNLOADS=1 before launching."
            ) from e
        configure_pipeline_memory(inpaint_pipe, vae_tiling=False)
        print("→ Inpainting seam for seamless tiling…")
        progress_cb("inpaint", 0, steps)
        inpainted = inpaint_pipe(
            prompt=prompt,
            image=shifted,
            mask_image=mask,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=width,
            height=height,
            callback_steps=1,
            callback=lambda step, timestep, kwargs: progress_cb("inpaint", step + 1, steps),
        ).images[0]
        # Free the pipeline before the (memory-hungry) upscale step.
        del inpaint_pipe, shifted, mask
        clear_torch_cache(device)
        # Shift back so the panorama is oriented as generated.
        inpainted = unshift_image(inpainted, shift_amt)
        inpainted.save(seamless_path)
        print(f"→ Crafted seamless image: {seamless_path}")
        final_source = inpainted
    else:
        image.save(seamless_path)
        print(f"→ Using raw output (seam inpaint disabled): {seamless_path}")
        final_source = image
    final_path = seamless_path
    if upscale and upscale != "none":
        try:
            if upscale is True or upscale == "topaz":
                final_path = run_topaz(seamless_path, tempdir)
            elif upscale == "realesrgan":
                final_path = run_realesrgan(
                    final_source,
                    tempdir,
                    scale=REALESRGAN_SCALE,
                    model_path=REALESRGAN_MODEL,
                    progress_cb=progress_cb
                )
            else:
                raise ValueError(f"Unknown upscale option '{upscale}'")
        except Exception as e:  # noqa: BLE001
            # Upscaling is best-effort: keep the non-upscaled image on failure.
            print(f"Upscaling failed ({upscale}); keeping seamless image: {e}")
    with Image.open(final_path) as final_img:
        final_img.load()
        save_png_with_prompt(final_img, output_path, prompt)
    try:
        with Image.open(output_path) as _im:
            print(f"→ Final image written to {output_path} [{_im.size[0]}x{_im.size[1]}]")
    except Exception:
        print(f"→ Final image written to {output_path}")
    return output_path
def restart_for_postprocess(
    prompt: str,
    input_path: str,
    output_path: str,
    tempdir: str,
    work_dir: str,
    upscale: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seam_inpaint: bool,
) -> None:
    """Re-exec this script in post-process mode to free generation memory.

    Builds the CLI for the hidden --postprocess-* entry point and replaces
    the current process via os.execv; on success this function never
    returns.
    """
    clear_torch_cache(select_device())
    script = os.path.abspath(__file__)
    argv = [
        sys.executable,
        script,
        "--prompt", prompt,
        "--postprocess-input", input_path,
        "--postprocess-output", output_path,
        "--postprocess-tempdir", tempdir,
        "--work-dir", work_dir,
        "--upscale", upscale or "none",
        "--steps", str(steps),
        "--guidance", str(guidance),
        "--width", str(width),
        "--height", str(height),
    ]
    if seam_inpaint:
        argv.append("--seam-inpaint")
    print("→ Restarting Python for post-processing to release generation model memory…", flush=True)
    os.execv(sys.executable, argv)
def generate(
    prompt: str,
    output_path: str,
    work_dir: str,
    upscale: str = "none",
    model_path: str = MODEL_PATH,
    base_model: str = BASE_SDXL_MODEL,
    vae_model: str = SDXL_VAE,
    steps: int = 25,
    guidance: float = 4.5,
    scheduler: str | None = None,
    width: int = 1024,
    height: int = 512,
    seam_inpaint: bool = False,
) -> str:
    """Generate an equirectangular panorama and run post-processing.

    Loading strategy, in order: local single-file SDXL checkpoint, native
    Diffusers repo (swapping in the repo's UNet over the SDXL base when the
    repo only ships a UNet), then single-file checkpoints fetched from the
    repo as a fallback.  Generation outputs latents which are decoded with a
    standalone VAE to keep peak memory down.  With seam_inpaint the process
    re-execs itself for post-processing; otherwise postprocess_image() runs
    in-process and its result path is returned.

    Raises RuntimeError when no pipeline can be loaded and ValueError for an
    unsupported scheduler name.
    """
    model_path = resolve_generation_model_path(model_path, work_dir)
    device = select_device()
    local_single_file = os.path.isfile(model_path)
    # Heuristic: a local file is assumed to be an SDXL checkpoint.
    is_sdxl = "sdxl" in model_path.lower() or local_single_file
    enable_upscale = bool(upscale and upscale != "none")
    os.makedirs(work_dir, exist_ok=True)
    if local_files_only():
        print(
            "→ Offline mode enabled: using local files/cache only "
            "(set SKYMAP_ALLOW_DOWNLOADS=1 to allow downloads).",
            flush=True,
        )
    with tempfile.TemporaryDirectory(dir=work_dir) as tempdir:
        print(f"→ Using tempdir: {tempdir}")
        gen_pipe = None
        load_errors: list[str] = []
        vae = None
        if is_sdxl and not local_single_file:
            try:
                # NOTE(review): 'stabilityai/sdxl-vae' publishes its weights at
                # the repo root; subfolder="vae" may fail here and fall through
                # to load_errors — confirm against the repo layout.
                vae = AutoencoderKL.from_pretrained(
                    vae_model,
                    subfolder="vae",
                    torch_dtype=torch.float32,
                    **diffusers_load_kwargs(),
                )
            except Exception as e:  # noqa: BLE001
                vae = None
                load_errors.append(f"vae: {e}")
        # Try native Diffusers repo
        try:
            if is_sdxl:
                pipe_kwargs = {"torch_dtype": torch.float32}
                if vae is not None:
                    pipe_kwargs["vae"] = vae
                pipe_kwargs.update(diffusers_load_kwargs())
                # Some SDXL repos only ship a UNet (e.g., sdxl-360); in that case load SDXL base
                # and swap the UNet to keep the rest of the components consistent.
                unet_only = False
                if not local_single_file and model_path and model_path.endswith("sdxl-360-diffusion"):
                    try:
                        # Probe for a UNet config; success means "UNet-only repo".
                        hf_hub_download(
                            model_path,
                            "unet/config.json",
                            local_files_only=local_files_only(),
                        )
                        unet_only = True
                    except Exception:
                        unet_only = False
                if local_single_file:
                    original_config_file = SDXL_SINGLE_FILE_CONFIG
                    if not os.path.isfile(original_config_file):
                        raise RuntimeError(
                            f"Local SDXL config not found at {original_config_file}. "
                            "This is required to load single-file checkpoints without fetching from GitHub."
                        )
                    gen_pipe = StableDiffusionXLPipeline.from_single_file(
                        model_path,
                        original_config_file=original_config_file,
                        **pipe_kwargs,
                    ).to(device)
                elif unet_only:
                    base_pipe = StableDiffusionXLPipeline.from_pretrained(
                        base_model,
                        **pipe_kwargs
                    ).to(device)
                    unet = UNet2DConditionModel.from_pretrained(
                        model_path,
                        subfolder="unet",
                        torch_dtype=torch.float32,
                        **diffusers_load_kwargs(),
                    ).to(device)
                    base_pipe.unet = unet
                    gen_pipe = base_pipe
                    del base_pipe, unet
                else:
                    gen_pipe = StableDiffusionXLPipeline.from_pretrained(
                        model_path,
                        **pipe_kwargs
                    ).to(device)
            else:
                gen_pipe = StableDiffusionPipeline.from_pretrained(
                    model_path,
                    torch_dtype=torch.float32,
                    **diffusers_load_kwargs(),
                ).to(device)
        except Exception as e:  # noqa: BLE001
            load_errors.append(f"from_pretrained: {e}")
        # Fallback: single-file SDXL checkpoint from the repo
        if gen_pipe is None and not local_single_file:
            ckpt_candidates = [
                "sdxl_360_diffusion.safetensors",
                "sdxl_360_diffusion_unet.safetensors",
                "model.safetensors",
            ]
            last_err = None
            for fname in ckpt_candidates:
                try:
                    ckpt = hf_hub_download(
                        model_path,
                        fname,
                        local_files_only=local_files_only(),
                    )
                    pipe_kwargs = {"torch_dtype": torch.float32}
                    if vae is not None:
                        pipe_kwargs["vae"] = vae
                    pipe_kwargs.update(diffusers_load_kwargs())
                    if is_sdxl and os.path.isfile(SDXL_SINGLE_FILE_CONFIG):
                        pipe_kwargs["original_config_file"] = SDXL_SINGLE_FILE_CONFIG
                    gen_pipe = StableDiffusionXLPipeline.from_single_file(
                        ckpt,
                        **pipe_kwargs
                    ).to(device)
                    break
                except Exception as e2:  # noqa: BLE001
                    last_err = e2
                    load_errors.append(f"{fname}: {e2}")
            if gen_pipe is None:
                raise RuntimeError(
                    f"Failed to load model '{model_path}'. "
                    "Ensure the repo/path exists (e.g., ProGamerGov/sdxl-360-diffusion) and "
                    "install accelerate for low_cpu_mem_usage: pip install accelerate. "
                    f"Errors: {load_errors}"
                ) from (last_err or Exception("No pipeline loaded"))
        if gen_pipe is None:
            mode = "offline/local-cache" if local_files_only() else "online"
            raise RuntimeError(
                f"Failed to load model '{model_path}' in {mode} mode. "
                "The local single-file checkpoint needs the bundled SDXL config and locally cached "
                "tokenizer/text-encoder metadata. "
                f"Errors: {load_errors}"
            )
        # Some loading paths can leave the pipeline without a VAE; attach one.
        if gen_pipe.vae is None and is_sdxl:
            try:
                vae = vae or AutoencoderKL.from_pretrained(
                    SDXL_VAE,
                    subfolder="vae",
                    torch_dtype=torch.float32,
                    **diffusers_load_kwargs(),
                )
                gen_pipe.vae = vae.to(device)
                gen_pipe.to(device)
            except Exception as e:  # noqa: BLE001
                load_errors.append(f"vae-fallback: {e}")
                raise RuntimeError(
                    "Loaded SDXL pipeline without a VAE; failed to attach the SDXL VAE. "
                    f"Errors: {load_errors}"
                ) from e
        # Optionally override scheduler; otherwise keep the pipeline default (Euler for SDXL base).
        if scheduler:
            sched_kind = scheduler.lower()
            sched_cfg = gen_pipe.scheduler.config
            if sched_kind in {"dpmsolver", "dpmsolver++"}:
                gen_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sched_cfg)
            elif sched_kind in {"dpmsolver-sde", "dpmsolver_sde"}:
                gen_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                    sched_cfg,
                    algorithm_type="sde-dpmsolver++"
                )
            elif sched_kind in {"euler"}:
                gen_pipe.scheduler = EulerDiscreteScheduler.from_config(sched_cfg)
            elif sched_kind in {"euler_a", "euler-ancestral", "euler-ancestral-discrete"}:
                gen_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sched_cfg)
            elif sched_kind in {"heun"}:
                gen_pipe.scheduler = HeunDiscreteScheduler.from_config(sched_cfg)
            elif sched_kind in {"ddim"}:
                gen_pipe.scheduler = DDIMScheduler.from_config(sched_cfg)
            else:
                raise ValueError(
                    f"Unsupported scheduler '{scheduler}'. "
                    "Try one of: euler, euler_a, heun, ddim, dpmsolver, dpmsolver-sde."
                )
        configure_pipeline_memory(gen_pipe)
        if "pipe_kwargs" in locals():
            del pipe_kwargs
        progress_cb = make_progress_cb(enable_upscale, seam_inpaint)
        print("→ Generating equirectangular HDRI…")
        progress_cb("gen", 0, steps)
        # Ask for latents so the VAE decode can happen after the UNet is freed.
        latent_output = gen_pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=width,
            height=height,
            callback_steps=1,
            callback=lambda step, timestep, kwargs: progress_cb("gen", step + 1, steps),
            output_type="latent",
        )
        latents = latent_output.images.detach().cpu()
        # Keep only the VAE; drop the full pipeline before decoding.
        decoder_vae = gen_pipe.vae
        vae = None
        del latent_output, gen_pipe
        clear_torch_cache(device)
        progress_cb("decode", 0, 1)
        image = decode_latents_to_image(decoder_vae, latents, device)
        progress_cb("decode", 1, 1)
        del latents, decoder_vae
        clear_torch_cache(device)
        gen_path = os.path.join(tempdir, f"base_{width}x{height}.png")
        image.save(gen_path)
        print(f"→ Saved initial image to {gen_path}")
        if seam_inpaint:
            del image
            clear_torch_cache(device)
            # Replaces this process; the child cleans up tempdir since this
            # TemporaryDirectory context never exits on success.
            restart_for_postprocess(
                prompt,
                gen_path,
                output_path,
                tempdir,
                work_dir,
                upscale,
                steps,
                guidance,
                width,
                height,
                seam_inpaint,
            )
            raise RuntimeError("Failed to restart Python for seam inpaint post-processing")
        return postprocess_image(
            prompt,
            gen_path,
            output_path,
            tempdir,
            upscale=upscale,
            steps=steps,
            guidance=guidance,
            width=width,
            height=height,
            seam_inpaint=False,
        )
def main():
    """CLI entry point.

    Parses arguments and either (a) runs the hidden post-process-only mode
    when --postprocess-input is given (used by the self re-exec in
    restart_for_postprocess), or (b) runs the full generate() flow, printing
    the final output path on success.
    """
    parser = argparse.ArgumentParser(
        description="Generate an equirectangular HDRI image"
    )
    parser.add_argument('--prompt', required=True, help='Text prompt for generation')
    parser.add_argument('--output', help='Output filename (PNG)')
    parser.add_argument('--output-dir', default='output', help='Directory for outputs (default: output)')
    parser.add_argument('--work-dir', default=os.path.dirname(os.path.abspath(__file__)), help='Working directory for temp files')
    parser.add_argument(
        '--upscale',
        choices=['none', 'topaz', 'realesrgan'],
        default='realesrgan',
        help='Optional upscaler: none, topaz (legacy), realesrgan (open source; default)'
    )
    parser.add_argument(
        '--model-path',
        default=MODEL_PATH,
        help='Diffusers model id or local path (default: local models/sdxl360Diffusion_v10.safetensors when present)'
    )
    parser.add_argument(
        '--base-model',
        default=BASE_SDXL_MODEL,
        help='SDXL base pipeline used when the model only provides a UNet (default: stabilityai/stable-diffusion-xl-base-1.0)'
    )
    parser.add_argument(
        '--vae-model',
        default=SDXL_VAE,
        help='VAE repo/path to load for SDXL models (default: stabilityai/sdxl-vae; try madebyollin/sdxl-vae-fp16-fix on Mac)'
    )
    parser.add_argument('--steps', type=int, default=25, help='Number of inference steps (default: 25)')
    parser.add_argument('--guidance', type=float, default=4.5, help='CFG guidance scale (default: 4.5)')
    parser.add_argument('--width', type=int, default=1024, help='Output width (default: 1024)')
    parser.add_argument('--height', type=int, default=512, help='Output height (default: 512)')
    parser.add_argument(
        '--seam-inpaint',
        action='store_true',
        help='Patch the horizontal wrap seam by shifting, inpainting the center seam, then shifting back'
    )
    parser.add_argument(
        '--scheduler',
        choices=['euler', 'euler_a', 'heun', 'ddim', 'dpmsolver', 'dpmsolver-sde'],
        help='Sampler/scheduler override; default uses the pipeline scheduler (Euler for SDXL base)'
    )
    parser.add_argument(
        '--allow-downloads',
        action='store_true',
        help='Allow missing model files to be downloaded from remote model hubs'
    )
    # Hidden flags used only by the self re-exec post-processing path.
    parser.add_argument('--postprocess-input', help=argparse.SUPPRESS)
    parser.add_argument('--postprocess-output', help=argparse.SUPPRESS)
    parser.add_argument('--postprocess-tempdir', help=argparse.SUPPRESS)
    args = parser.parse_args()
    if args.allow_downloads:
        # The env var is what downloads_allowed()/local_files_only() read.
        os.environ["SKYMAP_ALLOW_DOWNLOADS"] = "1"
    if args.postprocess_input:
        if not args.postprocess_output or not args.postprocess_tempdir:
            parser.error("--postprocess-input requires --postprocess-output and --postprocess-tempdir")
        try:
            result_path = postprocess_image(
                args.prompt,
                os.path.abspath(args.postprocess_input),
                os.path.abspath(args.postprocess_output),
                os.path.abspath(args.postprocess_tempdir),
                upscale=args.upscale,
                steps=args.steps,
                guidance=args.guidance,
                width=args.width,
                height=args.height,
                seam_inpaint=args.seam_inpaint,
            )
            print(result_path)
            # The parent's TemporaryDirectory never got cleaned up (execv),
            # so remove it here.
            shutil.rmtree(args.postprocess_tempdir, ignore_errors=True)
            return
        except Exception as e:  # noqa: BLE001
            print(f"Generation failed: {e}")
            raise
    base = sanitize_name(args.prompt)
    target = args.output or next_filename(args.output_dir, base, args.width, args.height)
    output_abs = os.path.abspath(target)
    try:
        result_path = generate(
            args.prompt,
            output_abs,
            args.work_dir,
            upscale=args.upscale,
            model_path=args.model_path,
            base_model=args.base_model,
            vae_model=args.vae_model,
            steps=args.steps,
            guidance=args.guidance,
            scheduler=args.scheduler,
            width=args.width,
            height=args.height,
            seam_inpaint=args.seam_inpaint,
        )
        print(result_path)
    except Exception as e:  # noqa: BLE001
        print(f"Generation failed: {e}")
        raise
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == '__main__':
    main()