auto-git:
[change] generate_equirect.py
This commit is contained in:
@@ -358,6 +358,7 @@ def generate(
|
|||||||
).to(device)
|
).to(device)
|
||||||
base_pipe.unet = unet
|
base_pipe.unet = unet
|
||||||
gen_pipe = base_pipe
|
gen_pipe = base_pipe
|
||||||
|
del base_pipe, unet
|
||||||
else:
|
else:
|
||||||
gen_pipe = StableDiffusionXLPipeline.from_pretrained(
|
gen_pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
@@ -440,8 +441,12 @@ def generate(
|
|||||||
"Try one of: euler, euler_a, heun, ddim, dpmsolver, dpmsolver-sde."
|
"Try one of: euler, euler_a, heun, ddim, dpmsolver, dpmsolver-sde."
|
||||||
)
|
)
|
||||||
gen_pipe.enable_attention_slicing()
|
gen_pipe.enable_attention_slicing()
|
||||||
if is_sdxl and vae is not None and hasattr(gen_pipe, "enable_vae_tiling"):
|
if hasattr(gen_pipe, "enable_vae_tiling"):
|
||||||
gen_pipe.enable_vae_tiling()
|
gen_pipe.enable_vae_tiling()
|
||||||
|
if hasattr(gen_pipe, "enable_vae_slicing"):
|
||||||
|
gen_pipe.enable_vae_slicing()
|
||||||
|
if "pipe_kwargs" in locals():
|
||||||
|
del pipe_kwargs
|
||||||
|
|
||||||
def progress_cb(phase: str, current: int, total: int):
|
def progress_cb(phase: str, current: int, total: int):
|
||||||
payload = {
|
payload = {
|
||||||
@@ -463,7 +468,19 @@ def generate(
|
|||||||
height=height,
|
height=height,
|
||||||
callback_steps=1,
|
callback_steps=1,
|
||||||
callback=lambda step, timestep, kwargs: progress_cb("gen", step + 1, steps),
|
callback=lambda step, timestep, kwargs: progress_cb("gen", step + 1, steps),
|
||||||
).images[0]
|
output_type="latent",
|
||||||
|
)
|
||||||
|
latents = latent_output.images.detach().cpu()
|
||||||
|
decoder_vae = gen_pipe.vae
|
||||||
|
vae = None
|
||||||
|
del latent_output, gen_pipe
|
||||||
|
clear_torch_cache(device)
|
||||||
|
|
||||||
|
progress_cb("decode", 0, 1)
|
||||||
|
image = decode_latents_to_image(decoder_vae, latents, device)
|
||||||
|
progress_cb("decode", 1, 1)
|
||||||
|
del latents, decoder_vae
|
||||||
|
clear_torch_cache(device)
|
||||||
|
|
||||||
gen_path = os.path.join(tempdir, f"base_{width}x{height}.png")
|
gen_path = os.path.join(tempdir, f"base_{width}x{height}.png")
|
||||||
image.save(gen_path)
|
image.save(gen_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user