#!/usr/bin/env python3 """ Generate an equirectangular HDRI image using Diffusers. Optional upscaling: --upscale topaz → legacy Topaz Photo AI CLI (if installed) --upscale realesrgan → open Real-ESRGAN in-Python upscaler (pip install realesrgan==0.3.0 basicsr opencv-python) Default flow: prompt in → equirectangular PNG out. Add --seam-inpaint to patch the horizontal wrap seam. """ import argparse import gc import json import math import os import re import shutil import tempfile import torch from PIL import Image, ImageDraw import numpy as np from diffusers import ( StableDiffusionPipeline, StableDiffusionXLPipeline, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler, DDIMScheduler, StableDiffusionInpaintPipeline, AutoencoderKL, UNet2DConditionModel, ) from huggingface_hub import hf_hub_download # Default panorama model. You can override via --model-path. # Uses the SDXL 360 diffusion checkpoint from ProGamerGov (single-file safetensors). 
MODEL_PATH = "ProGamerGov/sdxl-360-diffusion"
BASE_SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
INPAINT_MODEL = "Lykon/dreamshaper-8-inpainting"
TOPAZ_CLI = "/Applications/Topaz Photo AI.app/Contents/MacOS/Topaz Photo AI"
REALESRGAN_MODEL = "RealESRGAN_x4plus.pth"
REALESRGAN_SCALE = 4  # x4 model for full upscale
# Recommended VAE for SDXL checkpoints
SDXL_VAE = "stabilityai/sdxl-vae"
# Directory holding this script. Relative asset paths (the Real-ESRGAN weights)
# are also looked up here, so the script works from any working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Upper bound on the prompt-derived filename stem; most filesystems reject
# names longer than 255 bytes and prompts are often longer than that.
MAX_NAME_LEN = 100
# Single-file checkpoint names tried when a repo cannot be loaded as a
# Diffusers pipeline.
SINGLE_FILE_CANDIDATES = (
    "sdxl_360_diffusion.safetensors",
    "sdxl_360_diffusion_unet.safetensors",
    "model.safetensors",
)


def sanitize_name(prompt: str, max_len: int = MAX_NAME_LEN) -> str:
    """Turn a free-form prompt into a filesystem-safe filename stem.

    Lower-cases the prompt, collapses whitespace runs to ``_``, drops every
    character outside ``[a-z0-9_]`` and truncates to *max_len* characters.
    Returns ``"env"`` when nothing usable is left.
    """
    base = prompt.strip().lower()
    base = re.sub(r"\s+", "_", base)
    base = re.sub(r"[^a-z0-9_]+", "", base)
    return base[:max_len] or "env"


def next_filename(output_dir: str, base: str, width: int, height: int) -> str:
    """Return the first unused ``{base}-{i}-{width}x{height}.png`` path in *output_dir*.

    Creates *output_dir* if needed; ``i`` counts up from 1.
    """
    os.makedirs(output_dir, exist_ok=True)
    i = 1
    while True:
        candidate = os.path.join(output_dir, f"{base}-{i}-{width}x{height}.png")
        if not os.path.exists(candidate):
            return candidate
        i += 1


def shift_image(img: Image.Image, shift: int) -> Image.Image:
    """Roll *img* left by *shift* pixels, wrapping the cut-off columns to the right edge."""
    w, h = img.size
    out = Image.new("RGB", (w, h))
    out.paste(img.crop((shift, 0, w, h)), (0, 0))
    out.paste(img.crop((0, 0, shift, h)), (w - shift, 0))
    return out


def create_mask(width: int, height: int, mask_w: int) -> Image.Image:
    """Return an ``L``-mode mask with a centred white vertical band *mask_w* px wide.

    White (255) marks the region to inpaint; everything else is black (0).
    A non-positive *mask_w* yields an all-black mask.
    """
    mask = Image.new("L", (width, height), 0)
    if mask_w <= 0:
        return mask
    left = (width - mask_w) // 2
    # ImageDraw.rectangle includes both corner coordinates, hence the "- 1" on
    # the right/bottom edges to make the band exactly mask_w pixels wide.
    ImageDraw.Draw(mask).rectangle([left, 0, left + mask_w - 1, height - 1], fill=255)
    return mask


def unshift_image(img: Image.Image, shift: int) -> Image.Image:
    """Roll *img* right by *shift* pixels; the inverse of :func:`shift_image`."""
    w, h = img.size
    out = Image.new("RGB", (w, h))
    out.paste(img.crop((w - shift, 0, w, h)), (0, 0))
    out.paste(img.crop((0, 0, w - shift, h)), (shift, 0))
    return out


def select_device() -> str:
    """Return ``"mps"``, ``"cuda"`` or ``"cpu"``, preferring MPS, then CUDA."""
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def clear_torch_cache(device: str | None = None) -> None:
    """Run the garbage collector and release cached accelerator memory for *device*.

    Best-effort: an MPS synchronise failure is ignored, and ``torch.mps.empty_cache``
    is only called on torch builds that provide it.
    """
    gc.collect()
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device == "mps" and torch.backends.mps.is_available() and hasattr(torch, "mps"):
        try:
            torch.mps.synchronize()
        except Exception:  # noqa: BLE001 - best effort; cache release below still runs
            pass
        empty_cache = getattr(torch.mps, "empty_cache", None)
        if empty_cache is not None:
            empty_cache()


def decode_latents_to_image(vae: AutoencoderKL, latents: torch.Tensor, device: str) -> Image.Image:
    """Decode *latents* with *vae* and return the first image of the batch as RGB.

    The latents are un-scaled with ``latents_mean``/``latents_std`` when the VAE
    config defines both, otherwise by dividing by ``scaling_factor``. The decoder
    output is mapped from [-1, 1] to 8-bit. VAE tiling/slicing are enabled when
    available to limit peak memory on wide panoramas.
    """
    print("→ Decoding latent image with standalone VAE…")
    if hasattr(vae, "enable_tiling"):
        vae.enable_tiling()
    if hasattr(vae, "enable_slicing"):
        vae.enable_slicing()
    vae.to(device)
    vae.eval()
    vae_dtype = next(vae.parameters()).dtype

    with torch.inference_mode():
        latents = latents.to(device=device, dtype=vae_dtype)
        has_latents_mean = hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None
        has_latents_std = hasattr(vae.config, "latents_std") and vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latent_channels = len(vae.config.latents_mean)
            latents_mean = (
                torch.tensor(vae.config.latents_mean)
                .view(1, latent_channels, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(vae.config.latents_std)
                .view(1, latent_channels, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / vae.config.scaling_factor + latents_mean
        else:
            latents = latents / vae.config.scaling_factor

        decoded = vae.decode(latents, return_dict=False)[0]
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        image = decoded[0].detach().cpu().permute(1, 2, 0).float().numpy()
        del decoded, latents

    clear_torch_cache(device)
    return Image.fromarray((image * 255).round().astype("uint8"))


def run_topaz(input_path: str, tempdir: str) -> str:
    """Upscale *input_path* with the Topaz Photo AI CLI and return the newest output PNG.

    Topaz writes into a dedicated ``topaz`` subdirectory of *tempdir*, so the
    returned file cannot be one of the intermediate PNGs already saved in
    *tempdir*.

    Raises:
        RuntimeError: the CLI could not be started, exited non-zero, or wrote no PNG.
    """
    import subprocess

    print("→ Upscaling with Topaz Photo AI CLI…")
    out_dir = os.path.join(tempdir, "topaz")
    os.makedirs(out_dir, exist_ok=True)
    try:
        # Argument list, no shell: quotes or spaces in paths cannot break the command.
        result = subprocess.run([TOPAZ_CLI, "--cli", input_path, "-o", out_dir], check=False)
    except OSError as e:
        raise RuntimeError(f"Topaz invocation failed: {e}") from e
    if result.returncode != 0:
        raise RuntimeError("Topaz CLI returned non-zero exit code")
    upscaled_files = sorted(
        (os.path.join(out_dir, f) for f in os.listdir(out_dir) if f.lower().endswith(".png")),
        key=os.path.getmtime,
        reverse=True,
    )
    if not upscaled_files:
        raise RuntimeError("Topaz produced no PNG output")
    print(f"→ Upscaled file: {upscaled_files[0]}")
    return upscaled_files[0]


def _resolve_asset_path(path: str) -> str:
    """Return *path*, or its location next to this script when it is relative and absent from the CWD."""
    if path and not os.path.isabs(path) and not os.path.exists(path):
        candidate = os.path.join(SCRIPT_DIR, path)
        if os.path.exists(candidate):
            return candidate
    return path


def run_realesrgan(
    input_image: Image.Image,
    tempdir: str,
    scale: int = 4,
    model_path: str = REALESRGAN_MODEL,
    progress_cb=None,
) -> str:
    """Upscale *input_image* with Real-ESRGAN and return the path of the saved PNG.

    Args:
        input_image: Image to upscale (converted to RGB first).
        tempdir: Directory the ``realesrgan_x{scale}.png`` result is written to.
        scale: Network and output scale; must match the weights in *model_path*.
        model_path: Path to the ``.pth`` weights. A relative path missing from the
            current directory is also looked up next to this script.
        progress_cb: Optional ``(phase, current, total)`` callable, called with
            phase ``"upscale"`` at start, after every tile and at the end.

    Raises:
        RuntimeError: dependencies are missing, the weights cannot be found, or a
            tile fails to process.
    """
    try:
        import sys
        import types

        # Compatibility shim for newer torchvision where functional_tensor moved/renamed
        try:
            import torchvision.transforms._functional_tensor as _ft  # type: ignore
            sys.modules.setdefault("torchvision.transforms.functional_tensor", _ft)
        except ImportError:
            pass  # older torchvision still ships the public module
        from realesrgan import RealESRGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet
        import cv2
    except Exception as e:  # noqa: BLE001
        raise RuntimeError(
            "Real-ESRGAN dependencies missing. Install with: pip install realesrgan==0.3.0 basicsr opencv-python torchvision"
        ) from e

    model_path = _resolve_asset_path(model_path)
    if not model_path or not os.path.exists(model_path):
        raise RuntimeError(
            f"Real-ESRGAN model not found at {model_path!r}. "
            "Place RealESRGAN_x4plus.pth next to this script or update REALESRGAN_MODEL."
        )

    img_bgr = cv2.cvtColor(np.array(input_image.convert("RGB")), cv2.COLOR_RGB2BGR)

    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )
    upsampler = RealESRGANer(
        model_path=model_path,
        scale=scale,
        model=model,
        tile=64,  # aggressive tiling to keep memory & runtime manageable
        tile_pad=10,
        pre_pad=0,
        half=False,  # keep full precision for CPU/MPS
    )
    # No device is passed to RealESRGANer, which then picks one itself
    # (CUDA-or-CPU in realesrgan 0.3.0 — verify for other versions); report the
    # device it actually uses rather than select_device().
    device = getattr(upsampler, "device", select_device())
    print(f"→ Upscaling with Real-ESRGAN (x{scale}) on {device}…")

    def tile_process_with_progress(self):
        # Replacement for RealESRGANer.tile_process that also reports per-tile
        # progress. Each tile is upscaled with tile_pad pixels of context and
        # only its un-padded centre is copied into self.output.
        batch, channel, height, width = self.img.shape
        s = self.scale
        self.output = self.img.new_zeros((batch, channel, height * s, width * s))
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)
        total_tiles = max(1, tiles_x * tiles_y)

        for y in range(tiles_y):
            for x in range(tiles_x):
                tile_idx = y * tiles_x + x + 1
                # Tile bounds in input pixels, without and with padding.
                in_x0 = x * self.tile_size
                in_x1 = min(in_x0 + self.tile_size, width)
                in_y0 = y * self.tile_size
                in_y1 = min(in_y0 + self.tile_size, height)
                pad_x0 = max(in_x0 - self.tile_pad, 0)
                pad_x1 = min(in_x1 + self.tile_pad, width)
                pad_y0 = max(in_y0 - self.tile_pad, 0)
                pad_y1 = min(in_y1 + self.tile_pad, height)

                input_tile = self.img[:, :, pad_y0:pad_y1, pad_x0:pad_x1]
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    # Continuing would paste an undefined or stale tile and silently
                    # corrupt the image; fail so the caller can fall back instead.
                    raise RuntimeError(
                        f"Real-ESRGAN failed on tile {tile_idx}/{total_tiles}: {error}"
                    ) from error
                if progress_cb:
                    progress_cb("upscale", tile_idx, total_tiles)

                # Offset of the un-padded region inside the upscaled tile.
                tile_x0 = (in_x0 - pad_x0) * s
                tile_y0 = (in_y0 - pad_y0) * s
                self.output[:, :, in_y0 * s:in_y1 * s, in_x0 * s:in_x1 * s] = output_tile[
                    :,
                    :,
                    tile_y0:tile_y0 + (in_y1 - in_y0) * s,
                    tile_x0:tile_x0 + (in_x1 - in_x0) * s,
                ]

    if progress_cb:
        upsampler.tile_process = types.MethodType(tile_process_with_progress, upsampler)
        progress_cb("upscale", 0, 1)
    sr_img, _ = upsampler.enhance(img_bgr, outscale=scale)
    sr_img = Image.fromarray(cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB))
    if progress_cb:
        progress_cb("upscale", 1, 1)
    out_path = os.path.join(tempdir, f"realesrgan_x{scale}.png")
    sr_img.save(out_path)
    print(f"→ Real-ESRGAN output: {out_path}")
    return out_path


def _load_vae(repo: str) -> AutoencoderKL:
    """Load a float32 AutoencoderKL from *repo*, accepting both common repo layouts.

    Standalone VAE repos (e.g. stabilityai/sdxl-vae, madebyollin/sdxl-vae-fp16-fix)
    keep the weights at the repo root, while full pipeline repos keep them under
    ``vae/``. The root is tried first; if that fails, the subfolder error propagates.
    """
    try:
        return AutoencoderKL.from_pretrained(repo, torch_dtype=torch.float32)
    except Exception:  # noqa: BLE001 - retry with the pipeline-repo layout
        return AutoencoderKL.from_pretrained(repo, subfolder="vae", torch_dtype=torch.float32)


def _is_unet_only_repo(model_path: str) -> bool:
    """Return True for sdxl-360-diffusion style repos that publish only ``unet/``."""
    if not (model_path and model_path.endswith("sdxl-360-diffusion")):
        return False
    try:
        hf_hub_download(model_path, "unet/config.json")
    except Exception:  # noqa: BLE001 - missing file or no network: not UNet-only
        return False
    return True


def _load_generation_pipeline(model_path, base_model, vae, device, is_sdxl, load_errors):
    """Load the text-to-image pipeline for *model_path* and move it to *device*.

    Tries a Diffusers repo/directory first. For UNet-only SDXL repos, the UNet is
    plugged into *base_model*. If that fails, it tries known single-file
    checkpoint names from the repo. Each failure is appended to *load_errors*.

    Raises:
        RuntimeError: no loading strategy succeeded.
    """
    pipe_kwargs = {"torch_dtype": torch.float32}
    if vae is not None:
        pipe_kwargs["vae"] = vae

    # Try native Diffusers repo
    try:
        if not is_sdxl:
            return StableDiffusionPipeline.from_pretrained(
                model_path, torch_dtype=torch.float32
            ).to(device)
        if _is_unet_only_repo(model_path):
            unet = UNet2DConditionModel.from_pretrained(
                model_path, subfolder="unet", torch_dtype=torch.float32
            )
            # Passing the UNet as a component keeps from_pretrained from also
            # loading the base model's UNet only to discard it.
            return StableDiffusionXLPipeline.from_pretrained(
                base_model, unet=unet, **pipe_kwargs
            ).to(device)
        return StableDiffusionXLPipeline.from_pretrained(model_path, **pipe_kwargs).to(device)
    except Exception as e:  # noqa: BLE001
        load_errors.append(f"from_pretrained: {e}")

    # Fallback: single-file SDXL checkpoint from the repo
    last_err = None
    for fname in SINGLE_FILE_CANDIDATES:
        try:
            ckpt = hf_hub_download(model_path, fname)
            return StableDiffusionXLPipeline.from_single_file(ckpt, **pipe_kwargs).to(device)
        except Exception as e2:  # noqa: BLE001
            last_err = e2
            load_errors.append(f"{fname}: {e2}")

    raise RuntimeError(
        f"Failed to load model '{model_path}'. "
        "Ensure the repo/path exists (e.g., ProGamerGov/sdxl-360-diffusion) and "
        "install accelerate for low_cpu_mem_usage: pip install accelerate. "
        f"Errors: {load_errors}"
    ) from (last_err or Exception("No pipeline loaded"))


def _apply_scheduler(pipe, scheduler: str | None) -> None:
    """Swap *pipe*'s scheduler for the one named by *scheduler*; no-op when falsy.

    The new scheduler is built from the current scheduler's config.

    Raises:
        ValueError: *scheduler* is not a recognised name.
    """
    if not scheduler:
        return
    sched_kind = scheduler.lower()
    sched_cfg = pipe.scheduler.config
    if sched_kind in {"dpmsolver", "dpmsolver++"}:
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(sched_cfg)
    elif sched_kind in {"dpmsolver-sde", "dpmsolver_sde"}:
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            sched_cfg, algorithm_type="sde-dpmsolver++"
        )
    elif sched_kind == "euler":
        pipe.scheduler = EulerDiscreteScheduler.from_config(sched_cfg)
    elif sched_kind in {"euler_a", "euler-ancestral", "euler-ancestral-discrete"}:
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sched_cfg)
    elif sched_kind == "heun":
        pipe.scheduler = HeunDiscreteScheduler.from_config(sched_cfg)
    elif sched_kind == "ddim":
        pipe.scheduler = DDIMScheduler.from_config(sched_cfg)
    else:
        raise ValueError(
            f"Unsupported scheduler '{scheduler}'. "
            "Try one of: euler, euler_a, heun, ddim, dpmsolver, dpmsolver-sde."
        )


def _inpaint_seam(image, prompt, steps, guidance, width, height, device, progress_cb):
    """Repaint the horizontal wrap seam of *image* and return the result.

    Rolls the image by half its width so the left/right edges meet in the
    middle, inpaints a centred band ``width // 8`` px wide with INPAINT_MODEL,
    then rolls it back. The inpainting pipeline is released before returning.
    """
    shift_amt = width // 2
    mask_w = width // 8
    shifted = shift_image(image, shift_amt)
    mask = create_mask(width, height, mask_w)

    inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        INPAINT_MODEL, torch_dtype=torch.float32
    ).to(device)
    inpaint_pipe.enable_attention_slicing()

    print("→ Inpainting seam for seamless tiling…")
    progress_cb("inpaint", 0, steps)
    try:
        inpainted = inpaint_pipe(
            prompt=prompt,
            image=shifted,
            mask_image=mask,
            num_inference_steps=steps,
            guidance_scale=guidance,  # keep inpaint guidance in sync with cfg guidance
            width=width,
            height=height,
            callback_steps=1,
            callback=lambda step, timestep, latents: progress_cb("inpaint", step + 1, steps),
        ).images[0]
    finally:
        # Free the inpainting model before the memory-hungry upscale step.
        del inpaint_pipe
        clear_torch_cache(device)
    return unshift_image(inpainted, shift_amt)


def generate(
    prompt: str,
    output_path: str,
    work_dir: str,
    upscale: str = "none",
    model_path: str = MODEL_PATH,
    base_model: str = BASE_SDXL_MODEL,
    vae_model: str = SDXL_VAE,
    steps: int = 25,
    guidance: float = 4.5,
    scheduler: str | None = None,
    width: int = 1024,
    height: int = 512,
    seam_inpaint: bool = False,
) -> str:
    """Generate a panorama for *prompt* and write the final PNG to *output_path*.

    Steps: load the model, sample latents, decode them with the VAE, optionally
    inpaint the wrap seam, optionally upscale, then move the result to
    *output_path*. Intermediates live in a temporary directory under *work_dir*
    that is deleted afterwards. Progress is printed as ``PROGRESS {json}`` lines.

    Args:
        prompt: Text prompt.
        output_path: Destination PNG; its parent directory is created if missing.
        work_dir: Parent directory for the temporary working directory.
        upscale: ``"none"``, ``"topaz"`` (or ``True``) or ``"realesrgan"``. Upscaler
            failures are reported and the un-upscaled image is kept.
        model_path: Diffusers repo id or local path of the generation model.
        base_model: SDXL pipeline used when *model_path* ships only a UNet.
        vae_model: VAE repo/path loaded for SDXL models.
        steps: Denoising steps, used for both generation and seam inpainting.
        guidance: CFG scale, used for both generation and seam inpainting.
        scheduler: Optional scheduler override (see ``_apply_scheduler``).
        width: Output width in pixels.
        height: Output height in pixels.
        seam_inpaint: Whether to repaint the horizontal wrap seam.

    Returns:
        *output_path*.

    Raises:
        RuntimeError: the model or its VAE could not be loaded.
        ValueError: *scheduler* is not supported.
    """
    # Normalize common aliases that 404 on HF
    aliases = {
        "proximasan/sdxl-360-diffusion": "ProGamerGov/sdxl-360-diffusion",
        "proximasan": "ProGamerGov/sdxl-360-diffusion",
    }
    model_path = aliases.get(model_path, model_path)
    device = select_device()
    is_sdxl = "sdxl" in model_path.lower()
    enable_upscale = bool(upscale and upscale != "none")

    def progress_cb(phase: str, current: int, total: int):
        payload = {
            "phase": phase,
            "current": current,
            "total": total,
            "upscale": enable_upscale,
            "seamInpaint": seam_inpaint,
        }
        print(f"PROGRESS {json.dumps(payload)}", flush=True)

    os.makedirs(work_dir, exist_ok=True)
    with tempfile.TemporaryDirectory(dir=work_dir) as tempdir:
        print(f"→ Using tempdir: {tempdir}")
        load_errors: list[str] = []

        vae = None
        if is_sdxl:
            try:
                vae = _load_vae(vae_model)
            except Exception as e:  # noqa: BLE001
                load_errors.append(f"vae: {e}")

        gen_pipe = _load_generation_pipeline(
            model_path, base_model, vae, device, is_sdxl, load_errors
        )

        if gen_pipe.vae is None and is_sdxl:
            try:
                # Fall back to the known-good default VAE.
                vae = vae or _load_vae(SDXL_VAE)
                gen_pipe.vae = vae.to(device)
                gen_pipe.to(device)
            except Exception as e:  # noqa: BLE001
                load_errors.append(f"vae-fallback: {e}")
                raise RuntimeError(
                    "Loaded SDXL pipeline without a VAE; failed to attach the SDXL VAE. "
                    f"Errors: {load_errors}"
                ) from e

        # Optionally override scheduler; otherwise keep the pipeline default (Euler for SDXL base).
        _apply_scheduler(gen_pipe, scheduler)

        gen_pipe.enable_attention_slicing()
        if hasattr(gen_pipe, "enable_vae_tiling"):
            gen_pipe.enable_vae_tiling()
        if hasattr(gen_pipe, "enable_vae_slicing"):
            gen_pipe.enable_vae_slicing()

        print("→ Generating equirectangular HDRI…")
        progress_cb("gen", 0, steps)
        latent_output = gen_pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=width,
            height=height,
            callback_steps=1,
            callback=lambda step, timestep, latents: progress_cb("gen", step + 1, steps),
            output_type="latent",
        )
        latents = latent_output.images.detach().cpu()
        # Keep only the VAE and drop the rest of the pipeline before decoding,
        # presumably to lower peak memory for the full-resolution decode.
        decoder_vae = gen_pipe.vae
        vae = None
        del latent_output, gen_pipe
        clear_torch_cache(device)

        progress_cb("decode", 0, 1)
        image = decode_latents_to_image(decoder_vae, latents, device)
        progress_cb("decode", 1, 1)
        del latents, decoder_vae
        clear_torch_cache(device)

        gen_path = os.path.join(tempdir, f"base_{width}x{height}.png")
        image.save(gen_path)
        print(f"→ Saved initial image to {gen_path}")

        seamless_path = os.path.join(tempdir, os.path.basename(output_path))
        if seam_inpaint:
            final_source = _inpaint_seam(
                image, prompt, steps, guidance, width, height, device, progress_cb
            )
            final_source.save(seamless_path)
            print(f"→ Crafted seamless image: {seamless_path}")
        else:
            image.save(seamless_path)
            print(f"→ Using raw output (seam inpaint disabled): {seamless_path}")
            final_source = image

        final_path = seamless_path
        if enable_upscale:
            try:
                if upscale is True or upscale == "topaz":
                    final_path = run_topaz(seamless_path, tempdir)
                elif upscale == "realesrgan":
                    final_path = run_realesrgan(
                        final_source,
                        tempdir,
                        scale=REALESRGAN_SCALE,
                        model_path=REALESRGAN_MODEL,
                        progress_cb=progress_cb,
                    )
                else:
                    raise ValueError(f"Unknown upscale option '{upscale}'")
            except Exception as e:  # noqa: BLE001
                print(f"Upscaling failed ({upscale}); keeping seamless image: {e}")

        # --output may point into a directory that does not exist yet.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        shutil.move(final_path, output_path)

    try:
        with Image.open(output_path) as _im:
            print(f"→ Final image written to {output_path} [{_im.size[0]}x{_im.size[1]}]")
    except Exception:  # noqa: BLE001 - the size is informational only
        print(f"→ Final image written to {output_path}")
    return output_path


def main():
    """Parse command-line arguments, validate them and run :func:`generate`."""
    parser = argparse.ArgumentParser(
        description="Generate an equirectangular HDRI image"
    )
    parser.add_argument('--prompt', required=True, help='Text prompt for generation')
    parser.add_argument('--output', help='Output filename (PNG)')
    parser.add_argument('--output-dir', default='output', help='Directory for outputs (default: output)')
    parser.add_argument('--work-dir', default=SCRIPT_DIR, help='Working directory for temp files')
    parser.add_argument(
        '--upscale',
        choices=['none', 'topaz', 'realesrgan'],
        default='realesrgan',
        help='Optional upscaler: none, topaz (legacy), realesrgan (open source; default)'
    )
    parser.add_argument(
        '--model-path',
        default=MODEL_PATH,
        help='Diffusers model id or local path (default: ProGamerGov/sdxl-360-diffusion)'
    )
    parser.add_argument(
        '--base-model',
        default=BASE_SDXL_MODEL,
        help='SDXL base pipeline used when the model only provides a UNet (default: stabilityai/stable-diffusion-xl-base-1.0)'
    )
    parser.add_argument(
        '--vae-model',
        default=SDXL_VAE,
        help='VAE repo/path to load for SDXL models (default: stabilityai/sdxl-vae; try madebyollin/sdxl-vae-fp16-fix on Mac)'
    )
    parser.add_argument('--steps', type=int, default=25, help='Number of inference steps (default: 25)')
    parser.add_argument('--guidance', type=float, default=4.5, help='CFG guidance scale (default: 4.5)')
    parser.add_argument('--width', type=int, default=1024, help='Output width (default: 1024)')
    parser.add_argument('--height', type=int, default=512, help='Output height (default: 512)')
    parser.add_argument(
        '--seam-inpaint',
        action='store_true',
        help='Patch the horizontal wrap seam by shifting, inpainting the center seam, then shifting back'
    )
    parser.add_argument(
        '--scheduler',
        choices=['euler', 'euler_a', 'heun', 'ddim', 'dpmsolver', 'dpmsolver-sde'],
        help='Sampler/scheduler override; default uses the pipeline scheduler (Euler for SDXL base)'
    )
    args = parser.parse_args()

    # Fail fast, before multi-GB model downloads: the diffusion pipelines reject
    # sizes that are not multiples of 8 and need at least one denoising step.
    if args.width <= 0 or args.height <= 0 or args.width % 8 or args.height % 8:
        parser.error("--width and --height must be positive multiples of 8")
    if args.steps < 1:
        parser.error("--steps must be at least 1")

    base = sanitize_name(args.prompt)
    target = args.output or next_filename(args.output_dir, base, args.width, args.height)
    output_abs = os.path.abspath(target)
    try:
        result_path = generate(
            args.prompt,
            output_abs,
            args.work_dir,
            upscale=args.upscale,
            model_path=args.model_path,
            base_model=args.base_model,
            vae_model=args.vae_model,
            steps=args.steps,
            guidance=args.guidance,
            scheduler=args.scheduler,
            width=args.width,
            height=args.height,
            seam_inpaint=args.seam_inpaint,
        )
        print(result_path)
    except Exception as e:  # noqa: BLE001
        print(f"Generation failed: {e}")
        raise


if __name__ == '__main__':
    main()