Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
Add latents upscaling for highres fix (--highres_fix_latents_upscaling) and a VAE batch size option (--vae_batch_size)
This commit is contained in:
@@ -589,6 +589,8 @@ class PipelineLike():
|
|||||||
latents: Optional[torch.FloatTensor] = None,
|
latents: Optional[torch.FloatTensor] = None,
|
||||||
max_embeddings_multiples: Optional[int] = 3,
|
max_embeddings_multiples: Optional[int] = 3,
|
||||||
output_type: Optional[str] = "pil",
|
output_type: Optional[str] = "pil",
|
||||||
|
vae_batch_size: float = None,
|
||||||
|
return_latents: bool = False,
|
||||||
# return_dict: bool = True,
|
# return_dict: bool = True,
|
||||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||||
@@ -680,6 +682,9 @@ class PipelineLike():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||||
|
|
||||||
|
vae_batch_size = batch_size if vae_batch_size is None else (
|
||||||
|
int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
|
||||||
|
|
||||||
if strength < 0 or strength > 1:
|
if strength < 0 or strength > 1:
|
||||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||||
|
|
||||||
@@ -793,7 +798,6 @@ class PipelineLike():
|
|||||||
latents_dtype = text_embeddings.dtype
|
latents_dtype = text_embeddings.dtype
|
||||||
init_latents_orig = None
|
init_latents_orig = None
|
||||||
mask = None
|
mask = None
|
||||||
noise = None
|
|
||||||
|
|
||||||
if init_image is None:
|
if init_image is None:
|
||||||
# get the initial random noise unless the user supplied it
|
# get the initial random noise unless the user supplied it
|
||||||
@@ -825,6 +829,8 @@ class PipelineLike():
|
|||||||
if isinstance(init_image[0], PIL.Image.Image):
|
if isinstance(init_image[0], PIL.Image.Image):
|
||||||
init_image = [preprocess_image(im) for im in init_image]
|
init_image = [preprocess_image(im) for im in init_image]
|
||||||
init_image = torch.cat(init_image)
|
init_image = torch.cat(init_image)
|
||||||
|
if isinstance(init_image, list):
|
||||||
|
init_image = torch.stack(init_image)
|
||||||
|
|
||||||
# mask image to tensor
|
# mask image to tensor
|
||||||
if mask_image is not None:
|
if mask_image is not None:
|
||||||
@@ -835,9 +841,24 @@ class PipelineLike():
|
|||||||
|
|
||||||
# encode the init image into latents and scale the latents
|
# encode the init image into latents and scale the latents
|
||||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
if init_image.size()[2:] == (height // 8, width // 8):
|
||||||
init_latents = init_latent_dist.sample(generator=generator)
|
init_latents = init_image
|
||||||
init_latents = 0.18215 * init_latents
|
else:
|
||||||
|
if vae_batch_size >= batch_size:
|
||||||
|
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||||
|
init_latents = init_latent_dist.sample(generator=generator)
|
||||||
|
else:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
init_latents = []
|
||||||
|
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
||||||
|
init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
|
||||||
|
if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
|
||||||
|
init_latents.append(init_latent_dist.sample(generator=generator))
|
||||||
|
init_latents = torch.cat(init_latents)
|
||||||
|
|
||||||
|
init_latents = 0.18215 * init_latents
|
||||||
|
|
||||||
if len(init_latents) == 1:
|
if len(init_latents) == 1:
|
||||||
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
||||||
init_latents_orig = init_latents
|
init_latents_orig = init_latents
|
||||||
@@ -932,8 +953,19 @@ class PipelineLike():
|
|||||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if return_latents:
|
||||||
|
return (latents, False)
|
||||||
|
|
||||||
latents = 1 / 0.18215 * latents
|
latents = 1 / 0.18215 * latents
|
||||||
image = self.vae.decode(latents).sample
|
if vae_batch_size >= batch_size:
|
||||||
|
image = self.vae.decode(latents).sample
|
||||||
|
else:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
images = []
|
||||||
|
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
||||||
|
images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
|
||||||
|
image = torch.cat(images)
|
||||||
|
|
||||||
image = (image / 2 + 0.5).clamp(0, 1)
|
image = (image / 2 + 0.5).clamp(0, 1)
|
||||||
|
|
||||||
@@ -1820,7 +1852,7 @@ def preprocess_mask(mask):
|
|||||||
mask = mask.convert("L")
|
mask = mask.convert("L")
|
||||||
w, h = mask.size
|
w, h = mask.size
|
||||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||||
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
|
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
|
||||||
mask = np.array(mask).astype(np.float32) / 255.0
|
mask = np.array(mask).astype(np.float32) / 255.0
|
||||||
mask = np.tile(mask, (4, 1, 1))
|
mask = np.tile(mask, (4, 1, 1))
|
||||||
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
||||||
@@ -1862,6 +1894,7 @@ class BatchDataExt(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class BatchData(NamedTuple):
|
class BatchData(NamedTuple):
|
||||||
|
return_latents: bool
|
||||||
base: BatchDataBase
|
base: BatchDataBase
|
||||||
ext: BatchDataExt
|
ext: BatchDataExt
|
||||||
|
|
||||||
@@ -2296,9 +2329,9 @@ def main(args):
|
|||||||
# highres_fixの処理
|
# highres_fixの処理
|
||||||
if highres_fix and not highres_1st:
|
if highres_fix and not highres_1st:
|
||||||
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
||||||
print("process 1st stage1")
|
print("process 1st stage")
|
||||||
batch_1st = []
|
batch_1st = []
|
||||||
for base, ext in batch:
|
for _, base, ext in batch:
|
||||||
width_1st = int(ext.width * args.highres_fix_scale + .5)
|
width_1st = int(ext.width * args.highres_fix_scale + .5)
|
||||||
height_1st = int(ext.height * args.highres_fix_scale + .5)
|
height_1st = int(ext.height * args.highres_fix_scale + .5)
|
||||||
width_1st = width_1st - width_1st % 32
|
width_1st = width_1st - width_1st % 32
|
||||||
@@ -2306,20 +2339,29 @@ def main(args):
|
|||||||
|
|
||||||
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
|
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
|
||||||
ext.negative_scale, ext.strength, ext.network_muls)
|
ext.negative_scale, ext.strength, ext.network_muls)
|
||||||
batch_1st.append(BatchData(base, ext_1st))
|
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
||||||
images_1st = process_batch(batch_1st, True, True)
|
images_1st = process_batch(batch_1st, True, True)
|
||||||
|
|
||||||
# 2nd stageのバッチを作成して以下処理する
|
# 2nd stageのバッチを作成して以下処理する
|
||||||
print("process 2nd stage1")
|
print("process 2nd stage")
|
||||||
|
if args.highres_fix_latents_upscaling:
|
||||||
|
org_dtype = images_1st.dtype
|
||||||
|
if images_1st.dtype == torch.bfloat16:
|
||||||
|
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
|
||||||
|
images_1st = torch.nn.functional.interpolate(
|
||||||
|
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
|
||||||
|
images_1st = images_1st.to(org_dtype)
|
||||||
|
|
||||||
batch_2nd = []
|
batch_2nd = []
|
||||||
for i, (bd, image) in enumerate(zip(batch, images_1st)):
|
for i, (bd, image) in enumerate(zip(batch, images_1st)):
|
||||||
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
|
if not args.highres_fix_latents_upscaling:
|
||||||
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
|
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
|
||||||
|
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
|
||||||
batch_2nd.append(bd_2nd)
|
batch_2nd.append(bd_2nd)
|
||||||
batch = batch_2nd
|
batch = batch_2nd
|
||||||
|
|
||||||
# このバッチの情報を取り出す
|
# このバッチの情報を取り出す
|
||||||
(step_first, _, _, _, init_image, mask_image, _, guide_image), \
|
return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
|
||||||
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
|
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
|
||||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||||
|
|
||||||
@@ -2353,7 +2395,7 @@ def main(args):
|
|||||||
all_images_are_same = True
|
all_images_are_same = True
|
||||||
all_masks_are_same = True
|
all_masks_are_same = True
|
||||||
all_guide_images_are_same = True
|
all_guide_images_are_same = True
|
||||||
for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
||||||
prompts.append(prompt)
|
prompts.append(prompt)
|
||||||
negative_prompts.append(negative_prompt)
|
negative_prompts.append(negative_prompt)
|
||||||
seeds.append(seed)
|
seeds.append(seed)
|
||||||
@@ -2413,8 +2455,10 @@ def main(args):
|
|||||||
n.set_multiplier(m)
|
n.set_multiplier(m)
|
||||||
|
|
||||||
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
||||||
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
|
||||||
if highres_1st and not args.highres_fix_save_1st:
|
vae_batch_size=args.vae_batch_size, return_latents=return_latents,
|
||||||
|
clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
||||||
|
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
||||||
return images
|
return images
|
||||||
|
|
||||||
# save image
|
# save image
|
||||||
@@ -2612,9 +2656,9 @@ def main(args):
|
|||||||
print("Use previous image as guide image.")
|
print("Use previous image as guide image.")
|
||||||
guide_image = prev_image
|
guide_image = prev_image
|
||||||
|
|
||||||
b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||||
BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
|
BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
|
||||||
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
||||||
process_batch(batch_data, highres_fix)
|
process_batch(batch_data, highres_fix)
|
||||||
batch_data.clear()
|
batch_data.clear()
|
||||||
|
|
||||||
@@ -2658,6 +2702,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
|
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
|
||||||
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
|
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
||||||
|
parser.add_argument("--vae_batch_size", type=float, default=None,
|
||||||
|
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
|
||||||
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
||||||
parser.add_argument('--sampler', type=str, default='ddim',
|
parser.add_argument('--sampler', type=str, default='ddim',
|
||||||
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
||||||
@@ -2713,6 +2759,8 @@ if __name__ == '__main__':
|
|||||||
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
||||||
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
||||||
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
||||||
|
parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
|
||||||
|
help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
|
||||||
parser.add_argument("--negative_scale", type=float, default=None,
|
parser.add_argument("--negative_scale", type=float, default=None,
|
||||||
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user