latents upscaling in highres fix, vae batch size

This commit is contained in:
Kohya S
2023-02-25 18:17:18 +09:00
parent f0ae7eea95
commit 9993792656

View File

@@ -589,6 +589,8 @@ class PipelineLike():
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
max_embeddings_multiples: Optional[int] = 3, max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
vae_batch_size: float = None,
return_latents: bool = False,
# return_dict: bool = True, # return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -680,6 +682,9 @@ class PipelineLike():
else: else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
vae_batch_size = batch_size if vae_batch_size is None else (
int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -793,7 +798,6 @@ class PipelineLike():
latents_dtype = text_embeddings.dtype latents_dtype = text_embeddings.dtype
init_latents_orig = None init_latents_orig = None
mask = None mask = None
noise = None
if init_image is None: if init_image is None:
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
@@ -825,6 +829,8 @@ class PipelineLike():
if isinstance(init_image[0], PIL.Image.Image): if isinstance(init_image[0], PIL.Image.Image):
init_image = [preprocess_image(im) for im in init_image] init_image = [preprocess_image(im) for im in init_image]
init_image = torch.cat(init_image) init_image = torch.cat(init_image)
if isinstance(init_image, list):
init_image = torch.stack(init_image)
# mask image to tensor # mask image to tensor
if mask_image is not None: if mask_image is not None:
@@ -835,9 +841,24 @@ class PipelineLike():
# encode the init image into latents and scale the latents # encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype) init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist if init_image.size()[2:] == (height // 8, width // 8):
init_latents = init_latent_dist.sample(generator=generator) init_latents = init_image
init_latents = 0.18215 * init_latents else:
if vae_batch_size >= batch_size:
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
else:
if torch.cuda.is_available():
torch.cuda.empty_cache()
init_latents = []
for i in tqdm(range(0, batch_size, vae_batch_size)):
init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
init_latents.append(init_latent_dist.sample(generator=generator))
init_latents = torch.cat(init_latents)
init_latents = 0.18215 * init_latents
if len(init_latents) == 1: if len(init_latents) == 1:
init_latents = init_latents.repeat((batch_size, 1, 1, 1)) init_latents = init_latents.repeat((batch_size, 1, 1, 1))
init_latents_orig = init_latents init_latents_orig = init_latents
@@ -932,8 +953,19 @@ class PipelineLike():
if is_cancelled_callback is not None and is_cancelled_callback(): if is_cancelled_callback is not None and is_cancelled_callback():
return None return None
if return_latents:
return (latents, False)
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample if vae_batch_size >= batch_size:
image = self.vae.decode(latents).sample
else:
if torch.cuda.is_available():
torch.cuda.empty_cache()
images = []
for i in tqdm(range(0, batch_size, vae_batch_size)):
images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
image = torch.cat(images)
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
@@ -1820,7 +1852,7 @@ def preprocess_mask(mask):
mask = mask.convert("L") mask = mask.convert("L")
w, h = mask.size w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS) mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
mask = np.array(mask).astype(np.float32) / 255.0 mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1)) mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -1862,6 +1894,7 @@ class BatchDataExt(NamedTuple):
class BatchData(NamedTuple): class BatchData(NamedTuple):
return_latents: bool
base: BatchDataBase base: BatchDataBase
ext: BatchDataExt ext: BatchDataExt
@@ -2296,9 +2329,9 @@ def main(args):
# highres_fixの処理 # highres_fixの処理
if highres_fix and not highres_1st: if highres_fix and not highres_1st:
# 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す # 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
print("process 1st stage1") print("process 1st stage")
batch_1st = [] batch_1st = []
for base, ext in batch: for _, base, ext in batch:
width_1st = int(ext.width * args.highres_fix_scale + .5) width_1st = int(ext.width * args.highres_fix_scale + .5)
height_1st = int(ext.height * args.highres_fix_scale + .5) height_1st = int(ext.height * args.highres_fix_scale + .5)
width_1st = width_1st - width_1st % 32 width_1st = width_1st - width_1st % 32
@@ -2306,20 +2339,29 @@ def main(args):
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale, ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
ext.negative_scale, ext.strength, ext.network_muls) ext.negative_scale, ext.strength, ext.network_muls)
batch_1st.append(BatchData(base, ext_1st)) batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
images_1st = process_batch(batch_1st, True, True) images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する # 2nd stageのバッチを作成して以下処理する
print("process 2nd stage1") print("process 2nd stage")
if args.highres_fix_latents_upscaling:
org_dtype = images_1st.dtype
if images_1st.dtype == torch.bfloat16:
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
images_1st = torch.nn.functional.interpolate(
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
images_1st = images_1st.to(org_dtype)
batch_2nd = [] batch_2nd = []
for i, (bd, image) in enumerate(zip(batch, images_1st)): for i, (bd, image) in enumerate(zip(batch, images_1st)):
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定 if not args.highres_fix_latents_upscaling:
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext) image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
batch_2nd.append(bd_2nd) batch_2nd.append(bd_2nd)
batch = batch_2nd batch = batch_2nd
# このバッチの情報を取り出す # このバッチの情報を取り出す
(step_first, _, _, _, init_image, mask_image, _, guide_image), \ return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0] (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
@@ -2353,7 +2395,7 @@ def main(args):
all_images_are_same = True all_images_are_same = True
all_masks_are_same = True all_masks_are_same = True
all_guide_images_are_same = True all_guide_images_are_same = True
for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
prompts.append(prompt) prompts.append(prompt)
negative_prompts.append(negative_prompt) negative_prompts.append(negative_prompt)
seeds.append(seed) seeds.append(seed)
@@ -2413,8 +2455,10 @@ def main(args):
n.set_multiplier(m) n.set_multiplier(m)
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code, images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
if highres_1st and not args.highres_fix_save_1st: vae_batch_size=args.vae_batch_size, return_latents=return_latents,
clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
if highres_1st and not args.highres_fix_save_1st: # return images or latents
return images return images
# save image # save image
@@ -2612,9 +2656,9 @@ def main(args):
print("Use previous image as guide image.") print("Use previous image as guide image.")
guide_image = prev_image guide_image = prev_image
b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None)) BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要? if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
process_batch(batch_data, highres_fix) process_batch(batch_data, highres_fix)
batch_data.clear() batch_data.clear()
@@ -2658,6 +2702,8 @@ if __name__ == '__main__':
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
parser.add_argument("--vae_batch_size", type=float, default=None,
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
parser.add_argument('--sampler', type=str, default='ddim', parser.add_argument('--sampler', type=str, default='ddim',
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver', choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
@@ -2713,6 +2759,8 @@ if __name__ == '__main__':
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数") help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
parser.add_argument("--highres_fix_save_1st", action='store_true', parser.add_argument("--highres_fix_save_1st", action='store_true',
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
parser.add_argument("--negative_scale", type=float, default=None, parser.add_argument("--negative_scale", type=float, default=None,
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する") help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")