610 lines
23 KiB
Python
610 lines
23 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Generate an equirectangular HDRI image using Diffusers.
|
||
|
|
Optional upscaling:
|
||
|
|
--upscale topaz → legacy Topaz Photo AI CLI (if installed)
|
||
|
|
--upscale realesrgan → open Real-ESRGAN in-Python upscaler (pip install realesrgan==0.3.0 basicsr opencv-python)
|
||
|
|
Default flow: prompt in → equirectangular PNG out. Add --seam-inpaint to patch the horizontal wrap seam.
|
||
|
|
"""
|
||
|
|
import argparse
import gc
import json
import math
import os
import re
import shutil
import subprocess
import tempfile

import numpy as np
import torch
from PIL import Image, ImageDraw
|
||
|
|
|
||
|
|
from diffusers import (
|
||
|
|
StableDiffusionPipeline,
|
||
|
|
StableDiffusionXLPipeline,
|
||
|
|
DPMSolverMultistepScheduler,
|
||
|
|
EulerDiscreteScheduler,
|
||
|
|
EulerAncestralDiscreteScheduler,
|
||
|
|
HeunDiscreteScheduler,
|
||
|
|
DDIMScheduler,
|
||
|
|
StableDiffusionInpaintPipeline,
|
||
|
|
AutoencoderKL,
|
||
|
|
UNet2DConditionModel,
|
||
|
|
)
|
||
|
|
from huggingface_hub import hf_hub_download
|
||
|
|
|
||
|
|
# Default panorama model. You can override via --model-path.
# Uses the SDXL 360 diffusion checkpoint from ProGamerGov (single-file safetensors).
MODEL_PATH = "ProGamerGov/sdxl-360-diffusion"
# SDXL base pipeline used when the panorama repo only ships a UNet (see generate()).
BASE_SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
# Inpainting pipeline used for the optional --seam-inpaint pass.
INPAINT_MODEL = "Lykon/dreamshaper-8-inpainting"
# macOS install location of the legacy Topaz Photo AI CLI (--upscale topaz).
TOPAZ_CLI = "/Applications/Topaz Photo AI.app/Contents/MacOS/Topaz Photo AI"
# Real-ESRGAN weights file, resolved relative to the working directory (--upscale realesrgan).
REALESRGAN_MODEL = "RealESRGAN_x4plus.pth"
REALESRGAN_SCALE = 4  # x4 model for full upscale
# Recommended VAE for SDXL checkpoints
SDXL_VAE = "stabilityai/sdxl-vae"
|
||
|
|
|
||
|
|
|
||
|
|
def sanitize_name(prompt: str) -> str:
    """Turn a free-form prompt into a safe lowercase filename stem.

    Whitespace runs become underscores; anything outside [a-z0-9_] is
    dropped. Falls back to "env" when nothing usable remains.
    """
    slug = re.sub(r"\s+", "_", prompt.strip().lower())
    slug = re.sub(r"[^a-z0-9_]+", "", slug)
    return slug if slug else "env"
|
||
|
|
|
||
|
|
|
||
|
|
def next_filename(output_dir: str, base: str, width: int, height: int) -> str:
    """Return the first unused '<base>-<i>-<width>x<height>.png' path.

    Creates `output_dir` if needed and counts up from 1 until a
    non-existent candidate is found.
    """
    os.makedirs(output_dir, exist_ok=True)
    counter = 1
    while True:
        candidate = os.path.join(output_dir, f"{base}-{counter}-{width}x{height}.png")
        if not os.path.exists(candidate):
            return candidate
        counter += 1
|
||
|
|
|
||
|
|
|
||
|
|
def shift_image(img: Image.Image, shift: int) -> Image.Image:
    """Rotate the panorama horizontally by `shift` pixels with wrap-around.

    The columns [shift, w) move to the left edge and the columns [0, shift)
    wrap to the right edge, bringing the original seam into the interior.
    """
    width, height = img.size
    right_half = img.crop((shift, 0, width, height))
    left_half = img.crop((0, 0, shift, height))
    rotated = Image.new("RGB", (width, height))
    rotated.paste(right_half, (0, 0))
    rotated.paste(left_half, (width - shift, 0))
    return rotated
|
||
|
|
|
||
|
|
|
||
|
|
def create_mask(width: int, height: int, mask_w: int) -> Image.Image:
    """Build a grayscale inpaint mask: a white vertical band, centered.

    The band is `mask_w` pixels wide and spans the full height; the rest
    of the mask is black (untouched by inpainting).
    """
    band_left = (width - mask_w) // 2
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).rectangle(
        [band_left, 0, band_left + mask_w, height], fill=255
    )
    return mask
|
||
|
|
|
||
|
|
|
||
|
|
def unshift_image(img: Image.Image, shift: int) -> Image.Image:
    """Undo shift_image(): rotate the panorama back by `shift` pixels.

    The rightmost `shift` columns return to the left edge and the rest
    slide right, restoring the original column order.
    """
    width, height = img.size
    restored = Image.new("RGB", (width, height))
    tail = img.crop((width - shift, 0, width, height))
    head = img.crop((0, 0, width - shift, height))
    restored.paste(tail, (0, 0))
    restored.paste(head, (shift, 0))
    return restored
|
||
|
|
|
||
|
|
|
||
|
|
def select_device() -> str:
    """Pick the best available torch device: MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return "mps"
    return "cuda" if torch.cuda.is_available() else "cpu"
|
||
|
|
|
||
|
|
|
||
|
|
def clear_torch_cache(device: str | None = None) -> None:
|
||
|
|
gc.collect()
|
||
|
|
if device == "cuda" and torch.cuda.is_available():
|
||
|
|
torch.cuda.empty_cache()
|
||
|
|
elif device == "mps" and torch.backends.mps.is_available() and hasattr(torch, "mps"):
|
||
|
|
try:
|
||
|
|
torch.mps.synchronize()
|
||
|
|
except Exception:
|
||
|
|
pass
|
||
|
|
empty_cache = getattr(torch.mps, "empty_cache", None)
|
||
|
|
if empty_cache is not None:
|
||
|
|
empty_cache()
|
||
|
|
|
||
|
|
|
||
|
|
def decode_latents_to_image(vae: AutoencoderKL, latents: torch.Tensor, device: str) -> Image.Image:
    """Decode a diffusion latent tensor to a PIL RGB image with a standalone VAE.

    Applies the VAE's configured latent denormalization (latents_mean/std when
    present, otherwise plain scaling_factor division) before decoding, then
    maps the decoded tensor from [-1, 1] to [0, 255] uint8.

    NOTE(review): only the first batch element (decoded[0]) is converted;
    latents are assumed to be a 4-D (batch, channels, h, w) tensor — confirm
    with callers.
    """
    print("→ Decoding latent image with standalone VAE…")
    # Tiling/slicing reduce peak memory during decode; guarded because older
    # diffusers releases lack these methods.
    if hasattr(vae, "enable_tiling"):
        vae.enable_tiling()
    if hasattr(vae, "enable_slicing"):
        vae.enable_slicing()

    vae.to(device)
    vae.eval()
    # Match the latents' dtype to the VAE weights to avoid dtype-mismatch errors.
    vae_dtype = next(vae.parameters()).dtype

    with torch.inference_mode():
        latents = latents.to(device=device, dtype=vae_dtype)
        # Some VAEs publish per-channel latent statistics; when both are
        # present the latents must be denormalized with them instead of the
        # plain scaling_factor division.
        has_latents_mean = hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None
        has_latents_std = hasattr(vae.config, "latents_std") and vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latent_channels = len(vae.config.latents_mean)
            latents_mean = (
                torch.tensor(vae.config.latents_mean)
                .view(1, latent_channels, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(vae.config.latents_std)
                .view(1, latent_channels, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / vae.config.scaling_factor + latents_mean
        else:
            latents = latents / vae.config.scaling_factor

        decoded = vae.decode(latents, return_dict=False)[0]
        # VAE output lives in [-1, 1]; shift/scale into [0, 1] for image export.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        image = decoded[0].detach().cpu().permute(1, 2, 0).float().numpy()

    # Drop GPU references before asking the allocator to release memory.
    del decoded, latents
    clear_torch_cache(device)
    return Image.fromarray((image * 255).round().astype("uint8"))
|
||
|
|
|
||
|
|
|
||
|
|
def run_topaz(input_path: str, tempdir: str) -> str:
    """Upscale `input_path` with the Topaz Photo AI CLI, writing into `tempdir`.

    Returns the path of the most recently written PNG that Topaz produced.

    Raises:
        RuntimeError: If the CLI cannot be invoked, exits non-zero, or
            produces no PNG output.
    """
    print("→ Upscaling with Topaz Photo AI CLI…")
    try:
        # Use subprocess.run with an argument list instead of the previous
        # os.system() f-string: paths with quotes/metacharacters can no longer
        # break the command or inject shell syntax.
        completed = subprocess.run(
            [TOPAZ_CLI, "--cli", input_path, "-o", tempdir],
            check=False,
        )
        returncode = completed.returncode
    except Exception as e:  # noqa: BLE001
        raise RuntimeError(f"Topaz invocation failed: {e}") from e
    if returncode != 0:
        raise RuntimeError("Topaz CLI returned non-zero exit code")
    # Topaz names its output itself; take the newest PNG in the work dir.
    upscaled_files = sorted(
        [os.path.join(tempdir, f) for f in os.listdir(tempdir) if f.lower().endswith('.png')],
        key=os.path.getmtime,
        reverse=True
    )
    if not upscaled_files:
        raise RuntimeError("Topaz produced no PNG output")
    print(f"→ Upscaled file: {upscaled_files[0]}")
    return upscaled_files[0]
|
||
|
|
|
||
|
|
|
||
|
|
def run_realesrgan(
    input_image: Image.Image,
    tempdir: str,
    scale: int = 4,
    model_path: str = REALESRGAN_MODEL,
    progress_cb=None
) -> str:
    """Upscale a PIL image with Real-ESRGAN and save the result as a PNG.

    Args:
        input_image: Source image (RGB).
        tempdir: Directory that receives the upscaled PNG.
        scale: Upscale factor; must match the loaded model (x4 by default).
        model_path: Path to the Real-ESRGAN ``.pth`` weights file.
        progress_cb: Optional ``callable(phase, current, total)`` invoked per
            processed tile (phase is always "upscale").

    Returns:
        Path of the upscaled PNG inside `tempdir`.

    Raises:
        RuntimeError: If the Real-ESRGAN dependencies or the weights file
            are missing, or if a tile fails to process.
    """
    try:
        # Compatibility shim for newer torchvision where functional_tensor moved/renamed
        import sys
        try:
            import torchvision.transforms._functional_tensor as _ft  # type: ignore
            sys.modules.setdefault("torchvision.transforms.functional_tensor", _ft)
        except Exception:
            pass
        from realesrgan import RealESRGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet
        import cv2
        import types
    except Exception as e:  # noqa: BLE001
        raise RuntimeError(
            "Real-ESRGAN dependencies missing. Install with: pip install realesrgan==0.3.0 basicsr opencv-python torchvision"
        ) from e

    device = select_device()
    # (Removed dead `is_sdxl` check: it tested a .pth weights path for the
    # substring "sdxl" — a copy-paste artifact from generate() with no effect.)
    if not model_path or not os.path.exists(model_path):
        raise RuntimeError(
            f"Real-ESRGAN model not found at {model_path!r}. "
            "Place RealESRGAN_x4plus.pth next to this script or update REALESRGAN_MODEL."
        )
    # Real-ESRGAN (via OpenCV) expects BGR channel order.
    img_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    print(f"→ Upscaling with Real-ESRGAN (x{scale}) on {device}…")
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )
    upsampler = RealESRGANer(
        model_path=model_path,
        scale=scale,
        model=model,
        tile=64,  # aggressive tiling to keep memory & runtime manageable
        tile_pad=10,
        pre_pad=0,
        half=False,  # keep full precision for CPU/MPS
    )

    # Wrap tile processing to surface progress per tile.
    def tile_process_with_progress(self):
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)
        total_tiles = max(1, tiles_x * tiles_y)

        for y in range(tiles_y):
            for x in range(tiles_x):
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # Pad each tile so the network sees context beyond the seam.
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    # Bug fix: the previous version only printed the error and
                    # fell through, crashing later with an undefined
                    # `output_tile` (UnboundLocalError) that masked the real
                    # failure. Re-raise so the caller sees the actual error.
                    print('Error', error)
                    raise

                if progress_cb:
                    progress_cb("upscale", tile_idx, total_tiles)

                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # Crop the padding back off the upscaled tile.
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                      output_start_x_tile:output_end_x_tile]

    if progress_cb:
        upsampler.tile_process = types.MethodType(tile_process_with_progress, upsampler)

    if progress_cb:
        progress_cb("upscale", 0, 1)
    sr_img, _ = upsampler.enhance(img_bgr, outscale=scale)
    sr_img = Image.fromarray(cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB))
    if progress_cb:
        progress_cb("upscale", 1, 1)
    out_path = os.path.join(tempdir, f"realesrgan_x{scale}.png")
    sr_img.save(out_path)
    print(f"→ Real-ESRGAN output: {out_path}")
    return out_path
|
||
|
|
|
||
|
|
|
||
|
|
def generate(
    prompt: str,
    output_path: str,
    work_dir: str,
    upscale: str = "none",
    model_path: str = MODEL_PATH,
    base_model: str = BASE_SDXL_MODEL,
    vae_model: str = SDXL_VAE,
    steps: int = 25,
    guidance: float = 4.5,
    scheduler: str | None = None,
    width: int = 1024,
    height: int = 512,
    seam_inpaint: bool = False,
) -> str:
    """Generate an equirectangular panorama and write it to `output_path`.

    Pipeline: load the model (native Diffusers repo → UNet-swap onto the
    SDXL base → single-file safetensors fallback), run text-to-image,
    optionally inpaint the horizontal wrap seam, optionally upscale
    (Topaz CLI or Real-ESRGAN), then move the result into place.

    Progress is reported as `PROGRESS {json}` lines on stdout so a wrapping
    process can parse them.

    Returns `output_path`. Raises RuntimeError when no pipeline can be
    loaded or an SDXL pipeline ends up without a VAE; ValueError for an
    unknown scheduler name.
    """
    # Normalize common aliases that 404 on HF
    aliases = {
        "proximasan/sdxl-360-diffusion": "ProGamerGov/sdxl-360-diffusion",
        "proximasan": "ProGamerGov/sdxl-360-diffusion",
    }
    model_path = aliases.get(model_path, model_path)

    device = select_device()
    # Heuristic: treat the model as SDXL when its id/path mentions "sdxl".
    is_sdxl = "sdxl" in model_path.lower()
    scale = guidance  # keep inpaint guidance in sync with cfg guidance
    enable_upscale = bool(upscale and upscale != "none")

    os.makedirs(work_dir, exist_ok=True)

    # All intermediates live in a temp dir that is removed on exit; only the
    # final file is moved out.
    with tempfile.TemporaryDirectory(dir=work_dir) as tempdir:
        print(f"→ Using tempdir: {tempdir}")

        gen_pipe = None
        load_errors: list[str] = []  # collected for the final error message
        vae = None
        if is_sdxl:
            # Load the recommended SDXL VAE up front; failure is non-fatal
            # because the pipeline may ship its own VAE.
            try:
                vae = AutoencoderKL.from_pretrained(vae_model, subfolder="vae", torch_dtype=torch.float32)
            except Exception as e:  # noqa: BLE001
                load_errors.append(f"vae: {e}")

        # Try native Diffusers repo
        try:
            if is_sdxl:
                pipe_kwargs = {"torch_dtype": torch.float32}
                if vae is not None:
                    pipe_kwargs["vae"] = vae

                # Some SDXL repos only ship a UNet (e.g., sdxl-360); in that case load SDXL base
                # and swap the UNet to keep the rest of the components consistent.
                unet_only = False
                if model_path and model_path.endswith("sdxl-360-diffusion"):
                    # Probe for a unet/config.json; its presence signals a
                    # UNet-only repo.
                    try:
                        hf_hub_download(model_path, "unet/config.json")
                        unet_only = True
                    except Exception:
                        unet_only = False

                if unet_only:
                    base_pipe = StableDiffusionXLPipeline.from_pretrained(
                        base_model,
                        **pipe_kwargs
                    ).to(device)
                    unet = UNet2DConditionModel.from_pretrained(
                        model_path,
                        subfolder="unet",
                        torch_dtype=torch.float32
                    ).to(device)
                    base_pipe.unet = unet
                    gen_pipe = base_pipe
                else:
                    gen_pipe = StableDiffusionXLPipeline.from_pretrained(
                        model_path,
                        **pipe_kwargs
                    ).to(device)
            else:
                gen_pipe = StableDiffusionPipeline.from_pretrained(
                    model_path,
                    torch_dtype=torch.float32
                ).to(device)
        except Exception as e:  # noqa: BLE001
            load_errors.append(f"from_pretrained: {e}")

        # Fallback: single-file SDXL checkpoint from the repo
        if gen_pipe is None:
            ckpt_candidates = [
                "sdxl_360_diffusion.safetensors",
                "sdxl_360_diffusion_unet.safetensors",
                "model.safetensors",
            ]
            last_err = None
            for fname in ckpt_candidates:
                try:
                    ckpt = hf_hub_download(model_path, fname)
                    pipe_kwargs = {"torch_dtype": torch.float32}
                    if vae is not None:
                        pipe_kwargs["vae"] = vae
                    gen_pipe = StableDiffusionXLPipeline.from_single_file(
                        ckpt,
                        **pipe_kwargs
                    ).to(device)
                    break
                except Exception as e2:  # noqa: BLE001
                    last_err = e2
                    load_errors.append(f"{fname}: {e2}")
            if gen_pipe is None:
                raise RuntimeError(
                    f"Failed to load model '{model_path}'. "
                    "Ensure the repo/path exists (e.g., ProGamerGov/sdxl-360-diffusion) and "
                    "install accelerate for low_cpu_mem_usage: pip install accelerate. "
                    f"Errors: {load_errors}"
                ) from (last_err or Exception("No pipeline loaded"))
        # An SDXL pipeline without a VAE cannot decode; attach one or fail loudly.
        if gen_pipe.vae is None and is_sdxl:
            try:
                vae = vae or AutoencoderKL.from_pretrained(
                    SDXL_VAE,
                    subfolder="vae",
                    torch_dtype=torch.float32
                )
                gen_pipe.vae = vae.to(device)
                gen_pipe.to(device)
            except Exception as e:  # noqa: BLE001
                load_errors.append(f"vae-fallback: {e}")
                raise RuntimeError(
                    "Loaded SDXL pipeline without a VAE; failed to attach the SDXL VAE. "
                    f"Errors: {load_errors}"
                ) from e
        # Optionally override scheduler; otherwise keep the pipeline default (Euler for SDXL base).
        if scheduler:
            sched_kind = scheduler.lower()
            sched_cfg = gen_pipe.scheduler.config
            if sched_kind in {"dpmsolver", "dpmsolver++"}:
                gen_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sched_cfg)
            elif sched_kind in {"dpmsolver-sde", "dpmsolver_sde"}:
                gen_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                    sched_cfg,
                    algorithm_type="sde-dpmsolver++"
                )
            elif sched_kind in {"euler"}:
                gen_pipe.scheduler = EulerDiscreteScheduler.from_config(sched_cfg)
            elif sched_kind in {"euler_a", "euler-ancestral", "euler-ancestral-discrete"}:
                gen_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sched_cfg)
            elif sched_kind in {"heun"}:
                gen_pipe.scheduler = HeunDiscreteScheduler.from_config(sched_cfg)
            elif sched_kind in {"ddim"}:
                gen_pipe.scheduler = DDIMScheduler.from_config(sched_cfg)
            else:
                raise ValueError(
                    f"Unsupported scheduler '{scheduler}'. "
                    "Try one of: euler, euler_a, heun, ddim, dpmsolver, dpmsolver-sde."
                )
        # Memory savers: attention slicing always; VAE tiling when available.
        gen_pipe.enable_attention_slicing()
        if is_sdxl and vae is not None and hasattr(gen_pipe, "enable_vae_tiling"):
            gen_pipe.enable_vae_tiling()

        def progress_cb(phase: str, current: int, total: int):
            # Machine-readable progress line for a wrapping process to parse.
            payload = {
                "phase": phase,
                "current": current,
                "total": total,
                "upscale": enable_upscale,
                "seamInpaint": seam_inpaint,
            }
            print(f"PROGRESS {json.dumps(payload)}", flush=True)

        print("→ Generating equirectangular HDRI…")
        progress_cb("gen", 0, steps)
        # NOTE(review): callback/callback_steps is the legacy diffusers
        # callback API; newer releases replaced it with callback_on_step_end —
        # confirm against the pinned diffusers version.
        image = gen_pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=width,
            height=height,
            callback_steps=1,
            callback=lambda step, timestep, kwargs: progress_cb("gen", step + 1, steps),
        ).images[0]

        gen_path = os.path.join(tempdir, f"base_{width}x{height}.png")
        image.save(gen_path)
        print(f"→ Saved initial image to {gen_path}")

        seamless_path = os.path.join(tempdir, os.path.basename(output_path))
        if seam_inpaint:
            # Shift the wrap seam into the image center, inpaint a vertical
            # band over it, then shift back — hides the horizontal seam.
            shift_amt = width // 2
            mask_w = width // 8

            shifted = shift_image(image, shift_amt)
            mask = create_mask(width, height, mask_w)

            inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
                INPAINT_MODEL,
                torch_dtype=torch.float32
            ).to(device)
            inpaint_pipe.enable_attention_slicing()

            print("→ Inpainting seam for seamless tiling…")
            progress_cb("inpaint", 0, steps)
            inpainted = inpaint_pipe(
                prompt=prompt,
                image=shifted,
                mask_image=mask,
                num_inference_steps=steps,
                guidance_scale=scale,
                width=width,
                height=height,
                callback_steps=1,
                callback=lambda step, timestep, kwargs: progress_cb("inpaint", step + 1, steps),
            ).images[0]

            inpainted = unshift_image(inpainted, shift_amt)
            inpainted.save(seamless_path)
            print(f"→ Crafted seamless image: {seamless_path}")
            final_source = inpainted
        else:
            image.save(seamless_path)
            print(f"→ Using raw output (seam inpaint disabled): {seamless_path}")
            final_source = image

        final_path = seamless_path

        # Upscaling is best-effort: on any failure, fall back to the
        # non-upscaled image rather than aborting.
        if upscale and upscale != "none":
            try:
                if upscale is True or upscale == "topaz":
                    final_path = run_topaz(seamless_path, tempdir)
                elif upscale == "realesrgan":
                    final_path = run_realesrgan(
                        final_source,
                        tempdir,
                        scale=REALESRGAN_SCALE,
                        model_path=REALESRGAN_MODEL,
                        progress_cb=progress_cb
                    )
                else:
                    raise ValueError(f"Unknown upscale option '{upscale}'")
            except Exception as e:  # noqa: BLE001
                print(f"Upscaling failed ({upscale}); keeping seamless image: {e}")

        # Move out of the temp dir before it is deleted on context exit.
        shutil.move(final_path, output_path)
        try:
            with Image.open(output_path) as _im:
                print(f"→ Final image written to {output_path} [{_im.size[0]}x{_im.size[1]}]")
        except Exception:
            print(f"→ Final image written to {output_path}")
        return output_path
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """CLI entry point: parse arguments, derive the output path, run generate().

    The output filename defaults to a sanitized-prompt stem with an
    auto-incrementing index ('<slug>-<i>-<w>x<h>.png') under --output-dir;
    an explicit --output overrides it (and --output-dir is then ignored).
    """
    parser = argparse.ArgumentParser(
        description="Generate an equirectangular HDRI image"
    )
    parser.add_argument('--prompt', required=True, help='Text prompt for generation')
    parser.add_argument('--output', help='Output filename (PNG)')
    parser.add_argument('--output-dir', default='output', help='Directory for outputs (default: output)')
    parser.add_argument('--work-dir', default=os.path.dirname(os.path.abspath(__file__)), help='Working directory for temp files')
    parser.add_argument(
        '--upscale',
        choices=['none', 'topaz', 'realesrgan'],
        default='realesrgan',
        help='Optional upscaler: none, topaz (legacy), realesrgan (open source; default)'
    )
    parser.add_argument(
        '--model-path',
        default=MODEL_PATH,
        help='Diffusers model id or local path (default: ProGamerGov/sdxl-360-diffusion)'
    )
    parser.add_argument(
        '--base-model',
        default=BASE_SDXL_MODEL,
        help='SDXL base pipeline used when the model only provides a UNet (default: stabilityai/stable-diffusion-xl-base-1.0)'
    )
    parser.add_argument(
        '--vae-model',
        default=SDXL_VAE,
        help='VAE repo/path to load for SDXL models (default: stabilityai/sdxl-vae; try madebyollin/sdxl-vae-fp16-fix on Mac)'
    )
    parser.add_argument('--steps', type=int, default=25, help='Number of inference steps (default: 25)')
    parser.add_argument('--guidance', type=float, default=4.5, help='CFG guidance scale (default: 4.5)')
    parser.add_argument('--width', type=int, default=1024, help='Output width (default: 1024)')
    parser.add_argument('--height', type=int, default=512, help='Output height (default: 512)')
    parser.add_argument(
        '--seam-inpaint',
        action='store_true',
        help='Patch the horizontal wrap seam by shifting, inpainting the center seam, then shifting back'
    )
    parser.add_argument(
        '--scheduler',
        choices=['euler', 'euler_a', 'heun', 'ddim', 'dpmsolver', 'dpmsolver-sde'],
        help='Sampler/scheduler override; default uses the pipeline scheduler (Euler for SDXL base)'
    )
    args = parser.parse_args()

    # Build the destination path before generation so generate() can name
    # its temp copies after it.
    base = sanitize_name(args.prompt)
    target = args.output or next_filename(args.output_dir, base, args.width, args.height)
    output_abs = os.path.abspath(target)

    try:
        result_path = generate(
            args.prompt,
            output_abs,
            args.work_dir,
            upscale=args.upscale,
            model_path=args.model_path,
            base_model=args.base_model,
            vae_model=args.vae_model,
            steps=args.steps,
            guidance=args.guidance,
            scheduler=args.scheduler,
            width=args.width,
            height=args.height,
            seam_inpaint=args.seam_inpaint,
        )
        # Print the final path as the last stdout line for wrapping scripts.
        print(result_path)
    except Exception as e:  # noqa: BLE001
        # Log, then re-raise so the process exits non-zero with a traceback.
        print(f"Generation failed: {e}")
        raise
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|