From a85fcfe05f999a9dfa36edebc75ba9d62106b96a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 25 Apr 2023 08:10:21 +0900 Subject: [PATCH] fix latent upscale not working if bs>1 --- gen_img_diffusers.py | 2 +- tools/latent_upscaler.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 09c68000..988eae75 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -945,7 +945,7 @@ class PipelineLike: # encode the init image into latents and scale the latents init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[1:] == (height // 8, width // 8): + if init_image.size()[-2:] == (height // 8, width // 8): init_latents = init_image else: if vae_batch_size >= batch_size: diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index c69f983c..ab1fa339 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -243,7 +243,13 @@ def create_upscaler(**kwargs): model = Upscaler() print(f"Loading weights from {weights}...") - model.load_state_dict(torch.load(weights, map_location=torch.device("cpu"))) + if os.path.splitext(weights)[1] == ".safetensors": + from safetensors.torch import load_file + + sd = load_file(weights) + else: + sd = torch.load(weights, map_location=torch.device("cpu")) + model.load_state_dict(sd) return model