From 99937926560ba2716f61ea4a331724583e83afb7 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 25 Feb 2023 18:17:18 +0900
Subject: [PATCH] latents upscaling in highres fix, vae batch size
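
This patch adds two features to gen_img_diffusers.py:

- --vae_batch_size: run the VAE encode/decode in smaller chunks than the
  main batch so the VAE does not run out of memory on large batches. A value
  of 1 or more is used as an absolute VAE batch size; a value below 1.0 is
  used as a ratio of the main batch size (at least 1). The default (None)
  keeps the current behavior of one full batch.
- --highres_fix_latents_upscaling: in highres fix, upscale the 1st stage
  result in latent space with bilinear interpolation instead of decoding it
  to pixels and resizing with PIL.Image. The 1st stage returns its latents
  (return_latents), and the 2nd stage detects an init_image that already has
  the latent shape (height // 8, width // 8) and skips the VAE encode.

Editor's note: the snippet below is a minimal sketch of the two VAE batching
rules, not the patched code itself. resolve_vae_batch_size and
decode_in_chunks are illustrative helper names, and 0.18215 is the usual
Stable Diffusion latent scaling factor.

    import torch
    from diffusers import AutoencoderKL

    def resolve_vae_batch_size(batch_size: int, vae_batch_size=None) -> int:
        if vae_batch_size is None:
            return batch_size                            # default: one full batch
        if vae_batch_size >= 1:
            return int(vae_batch_size)                   # absolute chunk size
        return max(1, int(batch_size * vae_batch_size))  # < 1.0: ratio of batch size

    @torch.no_grad()
    def decode_in_chunks(vae: AutoencoderKL, latents: torch.Tensor, chunk: int) -> torch.Tensor:
        latents = latents / 0.18215  # undo the SD latent scaling before decoding
        images = [vae.decode(latents[i:i + chunk]).sample
                  for i in range(0, latents.shape[0], chunk)]
        return torch.cat(images)

For example, --batch_size 8 --vae_batch_size 0.25 runs the U-Net on all 8
images at once but the VAE on chunks of 2.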
---
 gen_img_diffusers.py | 84 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 66 insertions(+), 18 deletions(-)

diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index a2d5b945..f049e8a2 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -589,6 +589,8 @@ class PipelineLike():
       latents: Optional[torch.FloatTensor] = None,
       max_embeddings_multiples: Optional[int] = 3,
       output_type: Optional[str] = "pil",
+      vae_batch_size: float = None,
+      return_latents: bool = False,
       # return_dict: bool = True,
       callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
       is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -680,6 +682,9 @@
     else:
       raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
+    vae_batch_size = batch_size if vae_batch_size is None else (
+        int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
+
     if strength < 0 or strength > 1:
       raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
@@ -793,7 +798,6 @@
     latents_dtype = text_embeddings.dtype
     init_latents_orig = None
     mask = None
-    noise = None
 
     if init_image is None:
       # get the initial random noise unless the user supplied it
@@ -825,6 +829,8 @@
       if isinstance(init_image[0], PIL.Image.Image):
         init_image = [preprocess_image(im) for im in init_image]
         init_image = torch.cat(init_image)
+      if isinstance(init_image, list):
+        init_image = torch.stack(init_image)
 
       # mask image to tensor
       if mask_image is not None:
@@ -835,9 +841,24 @@
       # encode the init image into latents and scale the latents
       init_image = init_image.to(device=self.device, dtype=latents_dtype)
-      init_latent_dist = self.vae.encode(init_image).latent_dist
-      init_latents = init_latent_dist.sample(generator=generator)
-      init_latents = 0.18215 * init_latents
+      if init_image.size()[2:] == (height // 8, width // 8):
+        init_latents = init_image
+      else:
+        if vae_batch_size >= batch_size:
+          init_latent_dist = self.vae.encode(init_image).latent_dist
+          init_latents = init_latent_dist.sample(generator=generator)
+        else:
+          if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+          init_latents = []
+          for i in tqdm(range(0, batch_size, vae_batch_size)):
+            init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
+                                               if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
+            init_latents.append(init_latent_dist.sample(generator=generator))
+          init_latents = torch.cat(init_latents)
+
+        init_latents = 0.18215 * init_latents
+
       if len(init_latents) == 1:
         init_latents = init_latents.repeat((batch_size, 1, 1, 1))
       init_latents_orig = init_latents
 
@@ -932,8 +953,19 @@
       if is_cancelled_callback is not None and is_cancelled_callback():
         return None
 
+    if return_latents:
+      return (latents, False)
+
     latents = 1 / 0.18215 * latents
-    image = self.vae.decode(latents).sample
+    if vae_batch_size >= batch_size:
+      image = self.vae.decode(latents).sample
+    else:
+      if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+      images = []
+      for i in tqdm(range(0, batch_size, vae_batch_size)):
+        images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
+      image = torch.cat(images)
 
     image = (image / 2 + 0.5).clamp(0, 1)
 
@@ -1820,7 +1852,7 @@ def preprocess_mask(mask):
   mask = mask.convert("L")
   w, h = mask.size
   w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
+  mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR)      # LANCZOS)
   mask = np.array(mask).astype(np.float32) / 255.0
   mask = np.tile(mask, (4, 1, 1))
   mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
@@ -1862,6 +1894,7 @@ class BatchDataExt(NamedTuple):
 
 
 class BatchData(NamedTuple):
+  return_latents: bool
   base: BatchDataBase
   ext: BatchDataExt
 
@@ -2296,9 +2329,9 @@ def main(args):
     # highres_fixの処理
     if highres_fix and not highres_1st:
       # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
-      print("process 1st stage1")
+      print("process 1st stage")
       batch_1st = []
-      for base, ext in batch:
+      for _, base, ext in batch:
         width_1st = int(ext.width * args.highres_fix_scale + .5)
         height_1st = int(ext.height * args.highres_fix_scale + .5)
         width_1st = width_1st - width_1st % 32
@@ -2306,20 +2339,29 @@
         ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
                                ext.negative_scale, ext.strength, ext.network_muls)
-        batch_1st.append(BatchData(base, ext_1st))
+        batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
       images_1st = process_batch(batch_1st, True, True)
 
       # 2nd stageのバッチを作成して以下処理する
-      print("process 2nd stage1")
+      print("process 2nd stage")
+
+      if args.highres_fix_latents_upscaling:
+        org_dtype = images_1st.dtype
+        if images_1st.dtype == torch.bfloat16:
+          images_1st = images_1st.to(torch.float)  # interpolateがbf16をサポートしていない
+        images_1st = torch.nn.functional.interpolate(
+            images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear')  # , antialias=True)
+        images_1st = images_1st.to(org_dtype)
+
       batch_2nd = []
       for i, (bd, image) in enumerate(zip(batch, images_1st)):
-        image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # img2imgとして設定
-        bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
+        if not args.highres_fix_latents_upscaling:
+          image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # img2imgとして設定
+        bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
         batch_2nd.append(bd_2nd)
       batch = batch_2nd
 
     # このバッチの情報を取り出す
-    (step_first, _, _, _, init_image, mask_image, _, guide_image), \
+    return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
         (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
     noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
@@ -2353,7 +2395,7 @@
     all_images_are_same = True
     all_masks_are_same = True
     all_guide_images_are_same = True
-    for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
+    for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
      prompts.append(prompt)
      negative_prompts.append(negative_prompt)
      seeds.append(seed)
@@ -2413,8 +2455,10 @@
             n.set_multiplier(m)
 
     images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
-                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
-    if highres_1st and not args.highres_fix_save_1st:
+                  output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
+                  vae_batch_size=args.vae_batch_size, return_latents=return_latents,
+                  clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
+    if highres_1st and not args.highres_fix_save_1st:  # return images or latents
       return images
 
     # save image
@@ -2612,9 +2656,9 @@
            print("Use previous image as guide image.")
            guide_image = prev_image
 
-      b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
+      b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
                      BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
-      if len(batch_data) > 0 and batch_data[-1][1] != b1[1]:  # バッチ分割必要?
+      if len(batch_data) > 0 and batch_data[-1].ext != b1.ext:  # バッチ分割必要?
        process_batch(batch_data, highres_fix)
        batch_data.clear()
 
@@ -2658,6 +2702,8 @@ if __name__ == '__main__':
   parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
   parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
   parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
+  parser.add_argument("--vae_batch_size", type=float, default=None,
+                      help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
   parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
   parser.add_argument('--sampler', type=str, default='ddim',
                       choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
@@ -2713,6 +2759,8 @@
                       help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
   parser.add_argument("--highres_fix_save_1st", action='store_true',
                       help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
+  parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
+                      help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
   parser.add_argument("--negative_scale", type=float, default=None,
                       help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
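
--
Editor's note (not part of the patch): with --highres_fix_latents_upscaling,
the 1st stage latents are resized with torch.nn.functional.interpolate before
being fed to the 2nd stage as img2img input. Below is a minimal sketch of
that step; upscale_latents is an illustrative helper, latents_1st is assumed
to have shape (B, 4, h1 // 8, w1 // 8), and (H, W) is the 2nd stage pixel
size. The fp32 round-trip mirrors the patch's workaround for interpolate not
supporting bfloat16.

    import torch
    import torch.nn.functional as F

    def upscale_latents(latents_1st: torch.Tensor, H: int, W: int) -> torch.Tensor:
        org_dtype = latents_1st.dtype
        if org_dtype == torch.bfloat16:
            latents_1st = latents_1st.float()  # interpolate does not support bf16
        upscaled = F.interpolate(latents_1st, size=(H // 8, W // 8), mode='bilinear')
        return upscaled.to(org_dtype)  # restore the original dtype

Because the resize happens on the 4-channel latents rather than on decoded
pixels, the 1st stage result is never decoded to an image, which also saves
one VAE decode/encode round trip per batch.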