skymap-gen/generate_equirect.py

#!/usr/bin/env python3
"""
Generate an equirectangular HDRI image using Diffusers.
Optional upscaling:
--upscale topaz → legacy Topaz Photo AI CLI (if installed)
--upscale realesrgan → open-source Real-ESRGAN upscaler run in-process (pip install realesrgan==0.3.0 basicsr opencv-python)
Default flow: prompt in → equirectangular PNG out. Add --seam-inpaint to patch the horizontal wrap seam.
"""
import argparse
import gc
import json
import math
import os
import re
import shutil
import tempfile
import torch
from PIL import Image, ImageDraw
import numpy as np
from diffusers import (
StableDiffusionPipeline,
StableDiffusionXLPipeline,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
HeunDiscreteScheduler,
DDIMScheduler,
StableDiffusionInpaintPipeline,
AutoencoderKL,
UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
# Default panorama model. You can override via --model-path.
# Uses ProGamerGov's SDXL 360 diffusion checkpoint; loading below handles Diffusers repos, UNet-only repos, and single-file safetensors.
MODEL_PATH = "ProGamerGov/sdxl-360-diffusion"
BASE_SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
INPAINT_MODEL = "Lykon/dreamshaper-8-inpainting"
TOPAZ_CLI = "/Applications/Topaz Photo AI.app/Contents/MacOS/Topaz Photo AI"
REALESRGAN_MODEL = "RealESRGAN_x4plus.pth"
REALESRGAN_SCALE = 4 # x4 model for full upscale
# Recommended VAE for SDXL checkpoints
SDXL_VAE = "stabilityai/sdxl-vae"
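# Example CLI invocation (prompt and flag values are illustrative; see main() for all options):
#   python generate_equirect.py --prompt "alpine valley at sunset, volumetric light" \
#       --width 1024 --height 512 --steps 25 --upscale realesrgan --seam-inpaint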
def sanitize_name(prompt: str) -> str:
base = prompt.strip().lower()
base = re.sub(r"\s+", "_", base)
base = re.sub(r"[^a-z0-9_]+", "", base)
return base or "env"
def next_filename(output_dir: str, base: str, width: int, height: int) -> str:
os.makedirs(output_dir, exist_ok=True)
i = 1
while True:
fname = f"{base}-{i}-{width}x{height}.png"
candidate = os.path.join(output_dir, fname)
if not os.path.exists(candidate):
return candidate
i += 1
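# Seam helpers for --seam-inpaint: shift_image rolls the panorama horizontally so the wrap
# seam lands in the middle, create_mask marks a vertical band over it for inpainting, and
# unshift_image rolls the result back to the original framing.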
def shift_image(img: Image.Image, shift: int) -> Image.Image:
w, h = img.size
out = Image.new("RGB", (w, h))
out.paste(img.crop((shift, 0, w, h)), (0, 0))
out.paste(img.crop((0, 0, shift, h)), (w - shift, 0))
return out
def create_mask(width: int, height: int, mask_w: int) -> Image.Image:
mask = Image.new("L", (width, height), 0)
draw = ImageDraw.Draw(mask)
left = (width - mask_w) // 2
draw.rectangle([left, 0, left + mask_w, height], fill=255)
return mask
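# Worked example with the defaults used in generate(): width=1024 and mask_w=width//8=128,
# so left=(1024-128)//2=448 and the inpaint band covers x in [448, 576).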
def unshift_image(img: Image.Image, shift: int) -> Image.Image:
w, h = img.size
out = Image.new("RGB", (w, h))
out.paste(img.crop((w - shift, 0, w, h)), (0, 0))
out.paste(img.crop((0, 0, w - shift, h)), (shift, 0))
return out
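# Device selection checks Apple MPS before CUDA, falling back to CPU; the pipelines below
# stay in float32, which suits MPS.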
def select_device() -> str:
if torch.backends.mps.is_available():
return "mps"
if torch.cuda.is_available():
return "cuda"
return "cpu"
def clear_torch_cache(device: str | None = None) -> None:
gc.collect()
if device == "cuda" and torch.cuda.is_available():
torch.cuda.empty_cache()
elif device == "mps" and torch.backends.mps.is_available() and hasattr(torch, "mps"):
try:
torch.mps.synchronize()
except Exception:
pass
empty_cache = getattr(torch.mps, "empty_cache", None)
if empty_cache is not None:
empty_cache()
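# The generation pipeline is asked for latents only; they are decoded here with a standalone
# VAE after the UNet pipeline has been released, keeping peak memory lower. De-normalization
# follows the VAE config: latents_mean/latents_std when present, otherwise scaling_factor.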
def decode_latents_to_image(vae: AutoencoderKL, latents: torch.Tensor, device: str) -> Image.Image:
print("→ Decoding latent image with standalone VAE…")
if hasattr(vae, "enable_tiling"):
vae.enable_tiling()
if hasattr(vae, "enable_slicing"):
vae.enable_slicing()
vae.to(device)
vae.eval()
vae_dtype = next(vae.parameters()).dtype
with torch.inference_mode():
latents = latents.to(device=device, dtype=vae_dtype)
has_latents_mean = hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None
has_latents_std = hasattr(vae.config, "latents_std") and vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latent_channels = len(vae.config.latents_mean)
latents_mean = (
torch.tensor(vae.config.latents_mean)
.view(1, latent_channels, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(vae.config.latents_std)
.view(1, latent_channels, 1, 1)
.to(latents.device, latents.dtype)
)
latents = latents * latents_std / vae.config.scaling_factor + latents_mean
else:
latents = latents / vae.config.scaling_factor
decoded = vae.decode(latents, return_dict=False)[0]
decoded = (decoded / 2 + 0.5).clamp(0, 1)
image = decoded[0].detach().cpu().permute(1, 2, 0).float().numpy()
del decoded, latents
clear_torch_cache(device)
return Image.fromarray((image * 255).round().astype("uint8"))
def run_topaz(input_path: str, tempdir: str) -> str:
print("→ Upscaling with Topaz Photo AI CLI…")
result = None
try:
result = os.system(f'"{TOPAZ_CLI}" --cli "{input_path}" -o "{tempdir}"')
except Exception as e: # noqa: BLE001
raise RuntimeError(f"Topaz invocation failed: {e}") from e
if result != 0:
raise RuntimeError("Topaz CLI returned non-zero exit code")
upscaled_files = sorted(
[os.path.join(tempdir, f) for f in os.listdir(tempdir) if f.lower().endswith('.png')],
key=os.path.getmtime,
reverse=True
)
if not upscaled_files:
raise RuntimeError("Topaz produced no PNG output")
print(f"→ Upscaled file: {upscaled_files[0]}")
return upscaled_files[0]
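# In-process Real-ESRGAN upscale: the PIL image is converted to BGR for OpenCV, the standard
# x4 RRDBNet (23 blocks, 64 features) is built, and tiling (64 px tiles, 10 px pad) bounds
# memory. When a progress callback is supplied, tile_process is monkey-patched with a variant
# that mirrors the upstream implementation but reports per-tile progress.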
def run_realesrgan(
input_image: Image.Image,
tempdir: str,
scale: int = 4,
model_path: str = REALESRGAN_MODEL,
progress_cb=None
) -> str:
try:
# Compatibility shim for newer torchvision where functional_tensor moved/renamed
import sys
try:
import torchvision.transforms._functional_tensor as _ft # type: ignore
sys.modules.setdefault("torchvision.transforms.functional_tensor", _ft)
except Exception:
pass
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
import cv2
import types
except Exception as e: # noqa: BLE001
raise RuntimeError(
"Real-ESRGAN dependencies missing. Install with: pip install realesrgan==0.3.0 basicsr opencv-python torchvision"
) from e
device = select_device()
if not model_path or not os.path.exists(model_path):
raise RuntimeError(
f"Real-ESRGAN model not found at {model_path!r}. "
"Place RealESRGAN_x4plus.pth next to this script or update REALESRGAN_MODEL."
)
img_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
print(f"→ Upscaling with Real-ESRGAN (x{scale}) on {device}")
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)
upsampler = RealESRGANer(
model_path=model_path,
scale=scale,
model=model,
tile=64, # aggressive tiling to keep memory & runtime manageable
tile_pad=10,
pre_pad=0,
half=False, # keep full precision for CPU/MPS
)
# Wrap tile processing to surface progress per tile.
def tile_process_with_progress(self):
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
output_shape = (batch, channel, output_height, output_width)
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)
total_tiles = max(1, tiles_x * tiles_y)
for y in range(tiles_y):
for x in range(tiles_x):
ofs_x = x * self.tile_size
ofs_y = y * self.tile_size
input_start_x = ofs_x
input_end_x = min(ofs_x + self.tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + self.tile_size, height)
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
input_end_x_pad = min(input_end_x + self.tile_pad, width)
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
input_end_y_pad = min(input_end_y + self.tile_pad, height)
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
try:
with torch.no_grad():
output_tile = self.model(input_tile)
except RuntimeError as error:
print('Error', error)
if progress_cb:
progress_cb("upscale", tile_idx, total_tiles)
output_start_x = input_start_x * self.scale
output_end_x = input_end_x * self.scale
output_start_y = input_start_y * self.scale
output_end_y = input_end_y * self.scale
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
self.output[:, :, output_start_y:output_end_y,
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile]
if progress_cb:
upsampler.tile_process = types.MethodType(tile_process_with_progress, upsampler)
if progress_cb:
progress_cb("upscale", 0, 1)
sr_img, _ = upsampler.enhance(img_bgr, outscale=scale)
sr_img = Image.fromarray(cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB))
if progress_cb:
progress_cb("upscale", 1, 1)
out_path = os.path.join(tempdir, f"realesrgan_x{scale}.png")
sr_img.save(out_path)
print(f"→ Real-ESRGAN output: {out_path}")
return out_path
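# generate() flow: load the (SDXL) pipeline (with UNet-only and single-file fallbacks), render
# latents, decode them with a standalone VAE, optionally inpaint the wrap seam, optionally
# upscale, then move the result to output_path. Progress is emitted as machine-readable lines
# on stdout, e.g.
#   PROGRESS {"phase": "gen", "current": 3, "total": 25, "upscale": true, "seamInpaint": false}
# with phases "gen", "decode", "inpaint", and "upscale".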
def generate(
prompt: str,
output_path: str,
work_dir: str,
upscale: str = "none",
model_path: str = MODEL_PATH,
base_model: str = BASE_SDXL_MODEL,
vae_model: str = SDXL_VAE,
steps: int = 25,
guidance: float = 4.5,
scheduler: str | None = None,
width: int = 1024,
height: int = 512,
seam_inpaint: bool = False,
) -> str:
# Normalize common aliases that 404 on HF
aliases = {
"proximasan/sdxl-360-diffusion": "ProGamerGov/sdxl-360-diffusion",
"proximasan": "ProGamerGov/sdxl-360-diffusion",
}
model_path = aliases.get(model_path, model_path)
device = select_device()
is_sdxl = "sdxl" in model_path.lower()
inpaint_guidance = guidance # keep the seam-inpaint pass at the same CFG guidance as generation
enable_upscale = bool(upscale and upscale != "none")
os.makedirs(work_dir, exist_ok=True)
with tempfile.TemporaryDirectory(dir=work_dir) as tempdir:
print(f"→ Using tempdir: {tempdir}")
gen_pipe = None
load_errors: list[str] = []
vae = None
if is_sdxl:
try:
vae = AutoencoderKL.from_pretrained(vae_model, torch_dtype=torch.float32) # standalone VAE repos like stabilityai/sdxl-vae keep weights at the repo root (no "vae" subfolder)
except Exception as e: # noqa: BLE001
load_errors.append(f"vae: {e}")
# Try native Diffusers repo
try:
if is_sdxl:
pipe_kwargs = {"torch_dtype": torch.float32}
if vae is not None:
pipe_kwargs["vae"] = vae
# Some SDXL repos only ship a UNet (e.g., sdxl-360); in that case load SDXL base
# and swap the UNet to keep the rest of the components consistent.
unet_only = False
if model_path and model_path.endswith("sdxl-360-diffusion"):
try:
hf_hub_download(model_path, "unet/config.json")
unet_only = True
except Exception:
unet_only = False
if unet_only:
base_pipe = StableDiffusionXLPipeline.from_pretrained(
base_model,
**pipe_kwargs
).to(device)
unet = UNet2DConditionModel.from_pretrained(
model_path,
subfolder="unet",
torch_dtype=torch.float32
).to(device)
base_pipe.unet = unet
gen_pipe = base_pipe
del base_pipe, unet
else:
gen_pipe = StableDiffusionXLPipeline.from_pretrained(
model_path,
**pipe_kwargs
).to(device)
else:
gen_pipe = StableDiffusionPipeline.from_pretrained(
model_path,
torch_dtype=torch.float32
).to(device)
except Exception as e: # noqa: BLE001
load_errors.append(f"from_pretrained: {e}")
# Fallback: single-file SDXL checkpoint from the repo
if gen_pipe is None:
ckpt_candidates = [
"sdxl_360_diffusion.safetensors",
"sdxl_360_diffusion_unet.safetensors",
"model.safetensors",
]
last_err = None
for fname in ckpt_candidates:
try:
ckpt = hf_hub_download(model_path, fname)
pipe_kwargs = {"torch_dtype": torch.float32}
if vae is not None:
pipe_kwargs["vae"] = vae
gen_pipe = StableDiffusionXLPipeline.from_single_file(
ckpt,
**pipe_kwargs
).to(device)
break
except Exception as e2: # noqa: BLE001
last_err = e2
load_errors.append(f"{fname}: {e2}")
if gen_pipe is None:
raise RuntimeError(
f"Failed to load model '{model_path}'. "
"Ensure the repo/path exists (e.g., ProGamerGov/sdxl-360-diffusion) and "
"install accelerate for low_cpu_mem_usage: pip install accelerate. "
f"Errors: {load_errors}"
) from (last_err or Exception("No pipeline loaded"))
if gen_pipe.vae is None and is_sdxl:
try:
vae = vae or AutoencoderKL.from_pretrained(
SDXL_VAE,
torch_dtype=torch.float32
)
gen_pipe.vae = vae.to(device)
gen_pipe.to(device)
except Exception as e: # noqa: BLE001
load_errors.append(f"vae-fallback: {e}")
raise RuntimeError(
"Loaded SDXL pipeline without a VAE; failed to attach the SDXL VAE. "
f"Errors: {load_errors}"
) from e
# Optionally override scheduler; otherwise keep the pipeline default (Euler for SDXL base).
if scheduler:
sched_kind = scheduler.lower()
sched_cfg = gen_pipe.scheduler.config
if sched_kind in {"dpmsolver", "dpmsolver++"}:
gen_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sched_cfg)
elif sched_kind in {"dpmsolver-sde", "dpmsolver_sde"}:
gen_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
sched_cfg,
algorithm_type="sde-dpmsolver++"
)
elif sched_kind in {"euler"}:
gen_pipe.scheduler = EulerDiscreteScheduler.from_config(sched_cfg)
elif sched_kind in {"euler_a", "euler-ancestral", "euler-ancestral-discrete"}:
gen_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sched_cfg)
elif sched_kind in {"heun"}:
gen_pipe.scheduler = HeunDiscreteScheduler.from_config(sched_cfg)
elif sched_kind in {"ddim"}:
gen_pipe.scheduler = DDIMScheduler.from_config(sched_cfg)
else:
raise ValueError(
f"Unsupported scheduler '{scheduler}'. "
"Try one of: euler, euler_a, heun, ddim, dpmsolver, dpmsolver-sde."
)
gen_pipe.enable_attention_slicing()
if hasattr(gen_pipe, "enable_vae_tiling"):
gen_pipe.enable_vae_tiling()
if hasattr(gen_pipe, "enable_vae_slicing"):
gen_pipe.enable_vae_slicing()
if "pipe_kwargs" in locals():
del pipe_kwargs
def progress_cb(phase: str, current: int, total: int):
payload = {
"phase": phase,
"current": current,
"total": total,
"upscale": enable_upscale,
"seamInpaint": seam_inpaint,
}
print(f"PROGRESS {json.dumps(payload)}", flush=True)
print("→ Generating equirectangular HDRI…")
progress_cb("gen", 0, steps)
latent_output = gen_pipe(
prompt=prompt,
num_inference_steps=steps,
guidance_scale=guidance,
width=width,
height=height,
callback_steps=1,
callback=lambda step, timestep, kwargs: progress_cb("gen", step + 1, steps),
output_type="latent",
)
latents = latent_output.images.detach().cpu()
decoder_vae = gen_pipe.vae
vae = None
del latent_output, gen_pipe
clear_torch_cache(device)
progress_cb("decode", 0, 1)
image = decode_latents_to_image(decoder_vae, latents, device)
progress_cb("decode", 1, 1)
del latents, decoder_vae
clear_torch_cache(device)
gen_path = os.path.join(tempdir, f"base_{width}x{height}.png")
image.save(gen_path)
print(f"→ Saved initial image to {gen_path}")
seamless_path = os.path.join(tempdir, os.path.basename(output_path))
if seam_inpaint:
shift_amt = width // 2
mask_w = width // 8
shifted = shift_image(image, shift_amt)
mask = create_mask(width, height, mask_w)
inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
INPAINT_MODEL,
torch_dtype=torch.float32
).to(device)
inpaint_pipe.enable_attention_slicing()
print("→ Inpainting seam for seamless tiling…")
progress_cb("inpaint", 0, steps)
inpainted = inpaint_pipe(
prompt=prompt,
image=shifted,
mask_image=mask,
num_inference_steps=steps,
guidance_scale=inpaint_guidance,
width=width,
height=height,
callback_steps=1,
callback=lambda step, timestep, kwargs: progress_cb("inpaint", step + 1, steps),
).images[0]
inpainted = unshift_image(inpainted, shift_amt)
inpainted.save(seamless_path)
print(f"→ Crafted seamless image: {seamless_path}")
final_source = inpainted
else:
image.save(seamless_path)
print(f"→ Using raw output (seam inpaint disabled): {seamless_path}")
final_source = image
final_path = seamless_path
if upscale and upscale != "none":
try:
if upscale is True or upscale == "topaz": # upscale=True is accepted as a legacy alias for the Topaz path
final_path = run_topaz(seamless_path, tempdir)
elif upscale == "realesrgan":
final_path = run_realesrgan(
final_source,
tempdir,
scale=REALESRGAN_SCALE,
model_path=REALESRGAN_MODEL,
progress_cb=progress_cb
)
else:
raise ValueError(f"Unknown upscale option '{upscale}'")
except Exception as e: # noqa: BLE001
print(f"Upscaling failed ({upscale}); keeping seamless image: {e}")
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
shutil.move(final_path, output_path)
try:
with Image.open(output_path) as _im:
print(f"→ Final image written to {output_path} [{_im.size[0]}x{_im.size[1]}]")
except Exception:
print(f"→ Final image written to {output_path}")
return output_path
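# Programmatic use (illustrative values; the CLI in main() is the usual entry point):
#   generate("overcast harbor at dusk", "/tmp/harbor.png", work_dir="/tmp",
#            upscale="none", steps=25, width=1024, height=512)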
def main():
parser = argparse.ArgumentParser(
description="Generate an equirectangular HDRI image"
)
parser.add_argument('--prompt', required=True, help='Text prompt for generation')
parser.add_argument('--output', help='Output filename (PNG)')
parser.add_argument('--output-dir', default='output', help='Directory for outputs (default: output)')
parser.add_argument('--work-dir', default=os.path.dirname(os.path.abspath(__file__)), help='Working directory for temp files')
parser.add_argument(
'--upscale',
choices=['none', 'topaz', 'realesrgan'],
default='realesrgan',
help='Optional upscaler: none, topaz (legacy), realesrgan (open source; default)'
)
parser.add_argument(
'--model-path',
default=MODEL_PATH,
help='Diffusers model id or local path (default: ProGamerGov/sdxl-360-diffusion)'
)
parser.add_argument(
'--base-model',
default=BASE_SDXL_MODEL,
help='SDXL base pipeline used when the model only provides a UNet (default: stabilityai/stable-diffusion-xl-base-1.0)'
)
parser.add_argument(
'--vae-model',
default=SDXL_VAE,
help='VAE repo/path to load for SDXL models (default: stabilityai/sdxl-vae; try madebyollin/sdxl-vae-fp16-fix on Mac)'
)
parser.add_argument('--steps', type=int, default=25, help='Number of inference steps (default: 25)')
parser.add_argument('--guidance', type=float, default=4.5, help='CFG guidance scale (default: 4.5)')
parser.add_argument('--width', type=int, default=1024, help='Output width (default: 1024)')
parser.add_argument('--height', type=int, default=512, help='Output height (default: 512)')
parser.add_argument(
'--seam-inpaint',
action='store_true',
help='Patch the horizontal wrap seam by shifting, inpainting the center seam, then shifting back'
)
parser.add_argument(
'--scheduler',
choices=['euler', 'euler_a', 'heun', 'ddim', 'dpmsolver', 'dpmsolver-sde'],
help='Sampler/scheduler override; default uses the pipeline scheduler (Euler for SDXL base)'
)
args = parser.parse_args()
base = sanitize_name(args.prompt)
target = args.output or next_filename(args.output_dir, base, args.width, args.height)
output_abs = os.path.abspath(target)
try:
result_path = generate(
args.prompt,
output_abs,
args.work_dir,
upscale=args.upscale,
model_path=args.model_path,
base_model=args.base_model,
vae_model=args.vae_model,
steps=args.steps,
guidance=args.guidance,
scheduler=args.scheduler,
width=args.width,
height=args.height,
seam_inpaint=args.seam_inpaint,
)
print(result_path)
except Exception as e: # noqa: BLE001
print(f"Generation failed: {e}")
raise
if __name__ == '__main__':
main()