auto-git:
[add] README.md [add] default.png [add] equirect_hdr_icon_512-Wiederhergestellt.png [add] generate_equirect.py [add] icon.png [add] index.html [add] package-lock.json [add] package.json [add] public/ [add] requirements.txt [add] run.sh [add] src-tauri/ [add] src/ [add] vite.config.js
This commit is contained in:
609
generate_equirect.py
Normal file
609
generate_equirect.py
Normal file
@@ -0,0 +1,609 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate an equirectangular HDRI image using Diffusers.
|
||||
Optional upscaling:
|
||||
--upscale topaz → legacy Topaz Photo AI CLI (if installed)
|
||||
--upscale realesrgan → open Real-ESRGAN in-Python upscaler (pip install realesrgan==0.3.0 basicsr opencv-python)
|
||||
Default flow: prompt in → equirectangular PNG out. Add --seam-inpaint to patch the horizontal wrap seam.
|
||||
"""
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import torch
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDIMScheduler,
|
||||
StableDiffusionInpaintPipeline,
|
||||
AutoencoderKL,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
# Default panorama model. You can override via --model-path.
# Uses the SDXL 360 diffusion checkpoint from ProGamerGov (single-file safetensors).
MODEL_PATH = "ProGamerGov/sdxl-360-diffusion"
# SDXL base pipeline used when the panorama repo ships only a UNet (see generate()).
BASE_SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
# Inpainting checkpoint used for the optional --seam-inpaint pass.
INPAINT_MODEL = "Lykon/dreamshaper-8-inpainting"
# Legacy Topaz Photo AI CLI binary inside the macOS app bundle (--upscale topaz).
TOPAZ_CLI = "/Applications/Topaz Photo AI.app/Contents/MacOS/Topaz Photo AI"
# Real-ESRGAN weights file; run_realesrgan() expects it to exist at this path
# (relative to the working directory — presumably next to this script).
REALESRGAN_MODEL = "RealESRGAN_x4plus.pth"
REALESRGAN_SCALE = 4  # x4 model for full upscale
# Recommended VAE for SDXL checkpoints
SDXL_VAE = "stabilityai/sdxl-vae"
|
||||
|
||||
|
||||
def sanitize_name(prompt: str) -> str:
    """Turn a free-form prompt into a filesystem-safe base name.

    Lowercases the prompt, collapses whitespace runs into underscores,
    drops every remaining character outside [a-z0-9_], and falls back to
    "env" when nothing survives.
    """
    cleaned = re.sub(r"\s+", "_", prompt.strip().lower())
    cleaned = re.sub(r"[^a-z0-9_]+", "", cleaned)
    if cleaned:
        return cleaned
    return "env"
|
||||
|
||||
|
||||
def next_filename(output_dir: str, base: str, width: int, height: int) -> str:
    """Return the first free '<base>-<i>-<width>x<height>.png' path in output_dir.

    Creates output_dir if it does not exist, then probes i = 1, 2, 3, ...
    until a filename that is not already on disk is found.
    """
    os.makedirs(output_dir, exist_ok=True)
    index = 1
    candidate = os.path.join(output_dir, f"{base}-{index}-{width}x{height}.png")
    while os.path.exists(candidate):
        index += 1
        candidate = os.path.join(output_dir, f"{base}-{index}-{width}x{height}.png")
    return candidate
|
||||
|
||||
|
||||
def shift_image(img: Image.Image, shift: int) -> Image.Image:
    """Roll the image horizontally by `shift` pixels with wrap-around.

    The column at x=shift becomes the new left edge and the first `shift`
    columns wrap to the right side — used to move the panorama's wrap seam
    into the center of the frame before inpainting.
    """
    width, height = img.size
    rolled = Image.new("RGB", (width, height))
    right_part = img.crop((shift, 0, width, height))
    left_part = img.crop((0, 0, shift, height))
    rolled.paste(right_part, (0, 0))
    rolled.paste(left_part, (width - shift, 0))
    return rolled
|
||||
|
||||
|
||||
def create_mask(width: int, height: int, mask_w: int) -> Image.Image:
    """Build a grayscale inpainting mask: a full-height white band of width
    `mask_w` centered horizontally, black elsewhere.

    White (255) marks the region the inpainting pipeline is allowed to
    repaint (the shifted seam).
    """
    band_left = (width - mask_w) // 2
    mask = Image.new("L", (width, height), 0)
    ImageDraw.Draw(mask).rectangle([band_left, 0, band_left + mask_w, height], fill=255)
    return mask
|
||||
|
||||
|
||||
def unshift_image(img: Image.Image, shift: int) -> Image.Image:
    """Inverse of shift_image: roll the image right by `shift` pixels.

    unshift_image(shift_image(x, s), s) reproduces x, restoring the original
    seam position after the inpainting pass.
    """
    width, height = img.size
    restored = Image.new("RGB", (width, height))
    tail = img.crop((width - shift, 0, width, height))
    head = img.crop((0, 0, width - shift, height))
    restored.paste(tail, (0, 0))
    restored.paste(head, (shift, 0))
    return restored
|
||||
|
||||
|
||||
def select_device() -> str:
    """Pick the best available torch device name: "mps", then "cuda", then "cpu"."""
    preference = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for name, available in preference:
        if available():
            return name
    return "cpu"
|
||||
|
||||
|
||||
def clear_torch_cache(device: str | None = None) -> None:
|
||||
gc.collect()
|
||||
if device == "cuda" and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif device == "mps" and torch.backends.mps.is_available() and hasattr(torch, "mps"):
|
||||
try:
|
||||
torch.mps.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
empty_cache = getattr(torch.mps, "empty_cache", None)
|
||||
if empty_cache is not None:
|
||||
empty_cache()
|
||||
|
||||
|
||||
def decode_latents_to_image(vae: AutoencoderKL, latents: torch.Tensor, device: str) -> Image.Image:
    """Decode a latent tensor into a PIL image with a standalone VAE.

    Moves the VAE to `device`, un-scales the latents (honoring the optional
    latents_mean/latents_std entries in the VAE config), decodes, converts
    the first batch item to uint8 RGB, and frees the intermediates.
    """
    print("→ Decoding latent image with standalone VAE…")
    # Tiling/slicing keep peak memory manageable on large panoramas.
    for helper in ("enable_tiling", "enable_slicing"):
        if hasattr(vae, helper):
            getattr(vae, helper)()

    vae.to(device)
    vae.eval()
    target_dtype = next(vae.parameters()).dtype

    with torch.inference_mode():
        latents = latents.to(device=device, dtype=target_dtype)
        mean_cfg = getattr(vae.config, "latents_mean", None)
        std_cfg = getattr(vae.config, "latents_std", None)
        if mean_cfg is not None and std_cfg is not None:
            # Some VAEs store per-channel normalization stats; invert them.
            channels = len(mean_cfg)
            mean_t = torch.tensor(mean_cfg).view(1, channels, 1, 1).to(latents.device, latents.dtype)
            std_t = torch.tensor(std_cfg).view(1, channels, 1, 1).to(latents.device, latents.dtype)
            latents = latents * std_t / vae.config.scaling_factor + mean_t
        else:
            latents = latents / vae.config.scaling_factor

        decoded = vae.decode(latents, return_dict=False)[0]
        # VAE output is in [-1, 1]; map to [0, 1] before quantizing.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        pixels = decoded[0].detach().cpu().permute(1, 2, 0).float().numpy()

    # Drop the big tensors before returning to keep peak memory down.
    del decoded, latents
    clear_torch_cache(device)
    return Image.fromarray((pixels * 255).round().astype("uint8"))
|
||||
|
||||
|
||||
def run_topaz(input_path: str, tempdir: str) -> str:
    """Upscale `input_path` with the Topaz Photo AI CLI, writing into `tempdir`.

    Returns the path of the newest PNG found in `tempdir` afterwards.
    Raises RuntimeError when the CLI cannot be launched, exits non-zero,
    or produces no PNG output.
    """
    import subprocess  # local import: only needed for this optional upscaler

    print("→ Upscaling with Topaz Photo AI CLI…")
    try:
        # Argument list + shell=False instead of the old os.system() shell
        # string: no quoting bugs when paths contain spaces or quotes (the
        # Topaz app-bundle path itself contains spaces), and returncode is
        # the real exit code rather than a wait status.
        result = subprocess.run(
            [TOPAZ_CLI, "--cli", input_path, "-o", tempdir],
            check=False,
        )
    except OSError as e:
        raise RuntimeError(f"Topaz invocation failed: {e}") from e
    if result.returncode != 0:
        raise RuntimeError("Topaz CLI returned non-zero exit code")
    # Topaz picks the output filename itself; take the most recent PNG.
    upscaled_files = sorted(
        [os.path.join(tempdir, f) for f in os.listdir(tempdir) if f.lower().endswith('.png')],
        key=os.path.getmtime,
        reverse=True
    )
    if not upscaled_files:
        raise RuntimeError("Topaz produced no PNG output")
    print(f"→ Upscaled file: {upscaled_files[0]}")
    return upscaled_files[0]
|
||||
|
||||
|
||||
def run_realesrgan(
    input_image: Image.Image,
    tempdir: str,
    scale: int = 4,
    model_path: str = REALESRGAN_MODEL,
    progress_cb=None
) -> str:
    """Upscale a PIL image with Real-ESRGAN and return the output PNG path.

    Parameters:
        input_image: source image (RGB).
        tempdir:     directory where the upscaled PNG is written.
        scale:       upscale factor matching the weights (x4 for RealESRGAN_x4plus).
        model_path:  path to the .pth weights file.
        progress_cb: optional callable(phase, current, total) invoked per tile.

    Raises RuntimeError when the optional dependencies or the weights file
    are missing.
    """
    try:
        # Compatibility shim for newer torchvision where functional_tensor moved/renamed
        import sys
        try:
            import torchvision.transforms._functional_tensor as _ft  # type: ignore
            sys.modules.setdefault("torchvision.transforms.functional_tensor", _ft)
        except Exception:
            pass
        from realesrgan import RealESRGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet
        import cv2
        import types
    except Exception as e:  # noqa: BLE001
        raise RuntimeError(
            "Real-ESRGAN dependencies missing. Install with: pip install realesrgan==0.3.0 basicsr opencv-python torchvision"
        ) from e

    device = select_device()
    if not model_path or not os.path.exists(model_path):
        raise RuntimeError(
            f"Real-ESRGAN model not found at {model_path!r}. "
            "Place RealESRGAN_x4plus.pth next to this script or update REALESRGAN_MODEL."
        )
    # Real-ESRGAN/OpenCV work in BGR channel order.
    img_bgr = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    print(f"→ Upscaling with Real-ESRGAN (x{scale}) on {device}…")
    # Standard RealESRGAN_x4plus architecture parameters.
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )
    upsampler = RealESRGANer(
        model_path=model_path,
        scale=scale,
        model=model,
        tile=64,  # aggressive tiling to keep memory & runtime manageable
        tile_pad=10,
        pre_pad=0,
        half=False,  # keep full precision for CPU/MPS
    )

    # Wrap tile processing to surface progress per tile. This mirrors
    # RealESRGANer.tile_process with a progress_cb hook added.
    def tile_process_with_progress(self):
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)
        total_tiles = max(1, tiles_x * tiles_y)

        for y in range(tiles_y):
            for x in range(tiles_x):
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # Pad the tile so convolution context crosses tile borders.
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    # Bug fix: the previous version fell through after the
                    # failed inference and then read the undefined
                    # `output_tile`, raising NameError. Report progress and
                    # leave this tile as zeros instead.
                    print('Error', error)
                    if progress_cb:
                        progress_cb("upscale", tile_idx, total_tiles)
                    continue

                if progress_cb:
                    progress_cb("upscale", tile_idx, total_tiles)

                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # Crop the padding back out of the upscaled tile.
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                      output_start_x_tile:output_end_x_tile]

    if progress_cb:
        upsampler.tile_process = types.MethodType(tile_process_with_progress, upsampler)

    if progress_cb:
        progress_cb("upscale", 0, 1)
    sr_img, _ = upsampler.enhance(img_bgr, outscale=scale)
    sr_img = Image.fromarray(cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB))
    if progress_cb:
        progress_cb("upscale", 1, 1)
    out_path = os.path.join(tempdir, f"realesrgan_x{scale}.png")
    sr_img.save(out_path)
    print(f"→ Real-ESRGAN output: {out_path}")
    return out_path
|
||||
|
||||
|
||||
def generate(
    prompt: str,
    output_path: str,
    work_dir: str,
    upscale: str = "none",
    model_path: str = MODEL_PATH,
    base_model: str = BASE_SDXL_MODEL,
    vae_model: str = SDXL_VAE,
    steps: int = 25,
    guidance: float = 4.5,
    scheduler: str | None = None,
    width: int = 1024,
    height: int = 512,
    seam_inpaint: bool = False,
) -> str:
    """Generate an equirectangular panorama PNG and write it to output_path.

    Pipeline: load an SDXL (or SD 1.5) text-to-image pipeline for
    ``model_path``, render the panorama, optionally inpaint the horizontal
    wrap seam (``seam_inpaint``), optionally upscale ("topaz" or
    "realesrgan"), then move the result to ``output_path``. Progress is
    emitted as ``PROGRESS {json}`` lines on stdout — presumably parsed by
    the frontend that spawns this script; verify against the caller.

    Parameters mirror the CLI flags in main(). Returns output_path.
    Raises RuntimeError when no pipeline (or no VAE for SDXL) can be
    loaded, ValueError for an unknown scheduler name.
    """
    # Normalize common aliases that 404 on HF
    aliases = {
        "proximasan/sdxl-360-diffusion": "ProGamerGov/sdxl-360-diffusion",
        "proximasan": "ProGamerGov/sdxl-360-diffusion",
    }
    model_path = aliases.get(model_path, model_path)

    device = select_device()
    # SDXL detection is purely name-based: any model id containing "sdxl".
    is_sdxl = "sdxl" in model_path.lower()
    scale = guidance  # keep inpaint guidance in sync with cfg guidance
    enable_upscale = bool(upscale and upscale != "none")

    os.makedirs(work_dir, exist_ok=True)

    # All intermediates live in a temp dir that is deleted on exit; only the
    # final file is moved out to output_path.
    with tempfile.TemporaryDirectory(dir=work_dir) as tempdir:
        print(f"→ Using tempdir: {tempdir}")

        gen_pipe = None
        load_errors: list[str] = []  # accumulated for the final error message
        vae = None
        if is_sdxl:
            # NOTE(review): loads with subfolder="vae"; confirm the default
            # repo (stabilityai/sdxl-vae) actually nests its weights there —
            # failures are tolerated and retried via the fallback below.
            try:
                vae = AutoencoderKL.from_pretrained(vae_model, subfolder="vae", torch_dtype=torch.float32)
            except Exception as e:  # noqa: BLE001
                load_errors.append(f"vae: {e}")

        # Try native Diffusers repo
        try:
            if is_sdxl:
                pipe_kwargs = {"torch_dtype": torch.float32}
                if vae is not None:
                    pipe_kwargs["vae"] = vae

                # Some SDXL repos only ship a UNet (e.g., sdxl-360); in that case load SDXL base
                # and swap the UNet to keep the rest of the components consistent.
                unet_only = False
                if model_path and model_path.endswith("sdxl-360-diffusion"):
                    try:
                        # Probe for a UNet-only layout by fetching its config.
                        hf_hub_download(model_path, "unet/config.json")
                        unet_only = True
                    except Exception:
                        unet_only = False

                if unet_only:
                    base_pipe = StableDiffusionXLPipeline.from_pretrained(
                        base_model,
                        **pipe_kwargs
                    ).to(device)
                    unet = UNet2DConditionModel.from_pretrained(
                        model_path,
                        subfolder="unet",
                        torch_dtype=torch.float32
                    ).to(device)
                    base_pipe.unet = unet
                    gen_pipe = base_pipe
                else:
                    gen_pipe = StableDiffusionXLPipeline.from_pretrained(
                        model_path,
                        **pipe_kwargs
                    ).to(device)
            else:
                gen_pipe = StableDiffusionPipeline.from_pretrained(
                    model_path,
                    torch_dtype=torch.float32
                ).to(device)
        except Exception as e:  # noqa: BLE001
            load_errors.append(f"from_pretrained: {e}")

        # Fallback: single-file SDXL checkpoint from the repo
        if gen_pipe is None:
            ckpt_candidates = [
                "sdxl_360_diffusion.safetensors",
                "sdxl_360_diffusion_unet.safetensors",
                "model.safetensors",
            ]
            last_err = None
            for fname in ckpt_candidates:
                try:
                    ckpt = hf_hub_download(model_path, fname)
                    pipe_kwargs = {"torch_dtype": torch.float32}
                    if vae is not None:
                        pipe_kwargs["vae"] = vae
                    gen_pipe = StableDiffusionXLPipeline.from_single_file(
                        ckpt,
                        **pipe_kwargs
                    ).to(device)
                    break
                except Exception as e2:  # noqa: BLE001
                    last_err = e2
                    load_errors.append(f"{fname}: {e2}")
            if gen_pipe is None:
                raise RuntimeError(
                    f"Failed to load model '{model_path}'. "
                    "Ensure the repo/path exists (e.g., ProGamerGov/sdxl-360-diffusion) and "
                    "install accelerate for low_cpu_mem_usage: pip install accelerate. "
                    f"Errors: {load_errors}"
                ) from (last_err or Exception("No pipeline loaded"))
        # An SDXL pipeline cannot decode latents without a VAE; last-ditch
        # attempt to attach the recommended one before giving up.
        if gen_pipe.vae is None and is_sdxl:
            try:
                vae = vae or AutoencoderKL.from_pretrained(
                    SDXL_VAE,
                    subfolder="vae",
                    torch_dtype=torch.float32
                )
                gen_pipe.vae = vae.to(device)
                gen_pipe.to(device)
            except Exception as e:  # noqa: BLE001
                load_errors.append(f"vae-fallback: {e}")
                raise RuntimeError(
                    "Loaded SDXL pipeline without a VAE; failed to attach the SDXL VAE. "
                    f"Errors: {load_errors}"
                ) from e
        # Optionally override scheduler; otherwise keep the pipeline default (Euler for SDXL base).
        if scheduler:
            sched_kind = scheduler.lower()
            sched_cfg = gen_pipe.scheduler.config
            if sched_kind in {"dpmsolver", "dpmsolver++"}:
                gen_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sched_cfg)
            elif sched_kind in {"dpmsolver-sde", "dpmsolver_sde"}:
                gen_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                    sched_cfg,
                    algorithm_type="sde-dpmsolver++"
                )
            elif sched_kind in {"euler"}:
                gen_pipe.scheduler = EulerDiscreteScheduler.from_config(sched_cfg)
            elif sched_kind in {"euler_a", "euler-ancestral", "euler-ancestral-discrete"}:
                gen_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sched_cfg)
            elif sched_kind in {"heun"}:
                gen_pipe.scheduler = HeunDiscreteScheduler.from_config(sched_cfg)
            elif sched_kind in {"ddim"}:
                gen_pipe.scheduler = DDIMScheduler.from_config(sched_cfg)
            else:
                raise ValueError(
                    f"Unsupported scheduler '{scheduler}'. "
                    "Try one of: euler, euler_a, heun, ddim, dpmsolver, dpmsolver-sde."
                )
        # Memory savers: slice attention, and tile the VAE for SDXL.
        gen_pipe.enable_attention_slicing()
        if is_sdxl and vae is not None and hasattr(gen_pipe, "enable_vae_tiling"):
            gen_pipe.enable_vae_tiling()

        def progress_cb(phase: str, current: int, total: int) -> None:
            # Machine-readable progress line; keep format stable for the caller.
            payload = {
                "phase": phase,
                "current": current,
                "total": total,
                "upscale": enable_upscale,
                "seamInpaint": seam_inpaint,
            }
            print(f"PROGRESS {json.dumps(payload)}", flush=True)

        print("→ Generating equirectangular HDRI…")
        progress_cb("gen", 0, steps)
        image = gen_pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=width,
            height=height,
            callback_steps=1,
            callback=lambda step, timestep, kwargs: progress_cb("gen", step + 1, steps),
        ).images[0]

        gen_path = os.path.join(tempdir, f"base_{width}x{height}.png")
        image.save(gen_path)
        print(f"→ Saved initial image to {gen_path}")

        seamless_path = os.path.join(tempdir, os.path.basename(output_path))
        if seam_inpaint:
            # Roll the wrap seam to the center, repaint a band over it, then
            # roll back so left/right edges tile seamlessly.
            shift_amt = width // 2
            mask_w = width // 8

            shifted = shift_image(image, shift_amt)
            mask = create_mask(width, height, mask_w)

            inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
                INPAINT_MODEL,
                torch_dtype=torch.float32
            ).to(device)
            inpaint_pipe.enable_attention_slicing()

            print("→ Inpainting seam for seamless tiling…")
            progress_cb("inpaint", 0, steps)
            inpainted = inpaint_pipe(
                prompt=prompt,
                image=shifted,
                mask_image=mask,
                num_inference_steps=steps,
                guidance_scale=scale,
                width=width,
                height=height,
                callback_steps=1,
                callback=lambda step, timestep, kwargs: progress_cb("inpaint", step + 1, steps),
            ).images[0]

            inpainted = unshift_image(inpainted, shift_amt)
            inpainted.save(seamless_path)
            print(f"→ Crafted seamless image: {seamless_path}")
            final_source = inpainted
        else:
            image.save(seamless_path)
            print(f"→ Using raw output (seam inpaint disabled): {seamless_path}")
            final_source = image

        final_path = seamless_path

        # Upscaling is best-effort: on failure we keep the seamless image.
        if upscale and upscale != "none":
            try:
                # NOTE(review): `upscale is True` suggests some caller once
                # passed a bool for Topaz; the CLI only sends strings — confirm
                # before removing.
                if upscale is True or upscale == "topaz":
                    final_path = run_topaz(seamless_path, tempdir)
                elif upscale == "realesrgan":
                    final_path = run_realesrgan(
                        final_source,
                        tempdir,
                        scale=REALESRGAN_SCALE,
                        model_path=REALESRGAN_MODEL,
                        progress_cb=progress_cb
                    )
                else:
                    raise ValueError(f"Unknown upscale option '{upscale}'")
            except Exception as e:  # noqa: BLE001
                print(f"Upscaling failed ({upscale}); keeping seamless image: {e}")

        # Move out of the temp dir before it is cleaned up.
        shutil.move(final_path, output_path)
        try:
            with Image.open(output_path) as _im:
                print(f"→ Final image written to {output_path} [{_im.size[0]}x{_im.size[1]}]")
        except Exception:
            print(f"→ Final image written to {output_path}")
        return output_path
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the generator script."""
    parser = argparse.ArgumentParser(
        description="Generate an equirectangular HDRI image"
    )
    parser.add_argument('--prompt', required=True, help='Text prompt for generation')
    parser.add_argument('--output', help='Output filename (PNG)')
    parser.add_argument('--output-dir', default='output', help='Directory for outputs (default: output)')
    parser.add_argument('--work-dir', default=os.path.dirname(os.path.abspath(__file__)), help='Working directory for temp files')
    parser.add_argument(
        '--upscale',
        choices=['none', 'topaz', 'realesrgan'],
        default='realesrgan',
        help='Optional upscaler: none, topaz (legacy), realesrgan (open source; default)'
    )
    parser.add_argument(
        '--model-path',
        default=MODEL_PATH,
        help='Diffusers model id or local path (default: ProGamerGov/sdxl-360-diffusion)'
    )
    parser.add_argument(
        '--base-model',
        default=BASE_SDXL_MODEL,
        help='SDXL base pipeline used when the model only provides a UNet (default: stabilityai/stable-diffusion-xl-base-1.0)'
    )
    parser.add_argument(
        '--vae-model',
        default=SDXL_VAE,
        help='VAE repo/path to load for SDXL models (default: stabilityai/sdxl-vae; try madebyollin/sdxl-vae-fp16-fix on Mac)'
    )
    parser.add_argument('--steps', type=int, default=25, help='Number of inference steps (default: 25)')
    parser.add_argument('--guidance', type=float, default=4.5, help='CFG guidance scale (default: 4.5)')
    parser.add_argument('--width', type=int, default=1024, help='Output width (default: 1024)')
    parser.add_argument('--height', type=int, default=512, help='Output height (default: 512)')
    parser.add_argument(
        '--seam-inpaint',
        action='store_true',
        help='Patch the horizontal wrap seam by shifting, inpainting the center seam, then shifting back'
    )
    parser.add_argument(
        '--scheduler',
        choices=['euler', 'euler_a', 'heun', 'ddim', 'dpmsolver', 'dpmsolver-sde'],
        help='Sampler/scheduler override; default uses the pipeline scheduler (Euler for SDXL base)'
    )
    return parser


def main():
    """CLI entry point: parse arguments, pick an output path, run generate()."""
    args = _build_parser().parse_args()

    # Derive an output path from the prompt unless one was given explicitly.
    stem = sanitize_name(args.prompt)
    destination = args.output or next_filename(args.output_dir, stem, args.width, args.height)
    abs_destination = os.path.abspath(destination)

    try:
        final_path = generate(
            args.prompt,
            abs_destination,
            args.work_dir,
            upscale=args.upscale,
            model_path=args.model_path,
            base_model=args.base_model,
            vae_model=args.vae_model,
            steps=args.steps,
            guidance=args.guidance,
            scheduler=args.scheduler,
            width=args.width,
            height=args.height,
            seam_inpaint=args.seam_inpaint,
        )
        # Last stdout line is the result path, for the spawning process.
        print(final_path)
    except Exception as e:  # noqa: BLE001
        print(f"Generation failed: {e}")
        raise
|
||||
|
||||
|
||||
# Script entry point: delegate to the CLI handler when run directly.
if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user