mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
latents upscaling in highres fix, vae batch size
This commit is contained in:
@@ -589,6 +589,8 @@ class PipelineLike():
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
max_embeddings_multiples: Optional[int] = 3,
|
||||
output_type: Optional[str] = "pil",
|
||||
vae_batch_size: float = None,
|
||||
return_latents: bool = False,
|
||||
# return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
@@ -680,6 +682,9 @@ class PipelineLike():
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
vae_batch_size = batch_size if vae_batch_size is None else (
|
||||
int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size)))
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
@@ -793,7 +798,6 @@ class PipelineLike():
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_latents_orig = None
|
||||
mask = None
|
||||
noise = None
|
||||
|
||||
if init_image is None:
|
||||
# get the initial random noise unless the user supplied it
|
||||
@@ -825,6 +829,8 @@ class PipelineLike():
|
||||
if isinstance(init_image[0], PIL.Image.Image):
|
||||
init_image = [preprocess_image(im) for im in init_image]
|
||||
init_image = torch.cat(init_image)
|
||||
if isinstance(init_image, list):
|
||||
init_image = torch.stack(init_image)
|
||||
|
||||
# mask image to tensor
|
||||
if mask_image is not None:
|
||||
@@ -835,9 +841,24 @@ class PipelineLike():
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
if init_image.size()[2:] == (height // 8, width // 8):
|
||||
init_latents = init_image
|
||||
else:
|
||||
if vae_batch_size >= batch_size:
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
init_latents = []
|
||||
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
||||
init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size]
|
||||
if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist
|
||||
init_latents.append(init_latent_dist.sample(generator=generator))
|
||||
init_latents = torch.cat(init_latents)
|
||||
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
if len(init_latents) == 1:
|
||||
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
||||
init_latents_orig = init_latents
|
||||
@@ -932,8 +953,19 @@ class PipelineLike():
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
if return_latents:
|
||||
return (latents, False)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
if vae_batch_size >= batch_size:
|
||||
image = self.vae.decode(latents).sample
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
images = []
|
||||
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
||||
images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample)
|
||||
image = torch.cat(images)
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
@@ -1820,7 +1852,7 @@ def preprocess_mask(mask):
|
||||
mask = mask.convert("L")
|
||||
w, h = mask.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.LANCZOS)
|
||||
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS)
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
||||
@@ -1862,6 +1894,7 @@ class BatchDataExt(NamedTuple):
|
||||
|
||||
|
||||
class BatchData(NamedTuple):
|
||||
return_latents: bool
|
||||
base: BatchDataBase
|
||||
ext: BatchDataExt
|
||||
|
||||
@@ -2296,9 +2329,9 @@ def main(args):
|
||||
# highres_fixの処理
|
||||
if highres_fix and not highres_1st:
|
||||
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
||||
print("process 1st stage1")
|
||||
print("process 1st stage")
|
||||
batch_1st = []
|
||||
for base, ext in batch:
|
||||
for _, base, ext in batch:
|
||||
width_1st = int(ext.width * args.highres_fix_scale + .5)
|
||||
height_1st = int(ext.height * args.highres_fix_scale + .5)
|
||||
width_1st = width_1st - width_1st % 32
|
||||
@@ -2306,20 +2339,29 @@ def main(args):
|
||||
|
||||
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
|
||||
ext.negative_scale, ext.strength, ext.network_muls)
|
||||
batch_1st.append(BatchData(base, ext_1st))
|
||||
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
||||
images_1st = process_batch(batch_1st, True, True)
|
||||
|
||||
# 2nd stageのバッチを作成して以下処理する
|
||||
print("process 2nd stage1")
|
||||
print("process 2nd stage")
|
||||
if args.highres_fix_latents_upscaling:
|
||||
org_dtype = images_1st.dtype
|
||||
if images_1st.dtype == torch.bfloat16:
|
||||
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
|
||||
images_1st = torch.nn.functional.interpolate(
|
||||
images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True)
|
||||
images_1st = images_1st.to(org_dtype)
|
||||
|
||||
batch_2nd = []
|
||||
for i, (bd, image) in enumerate(zip(batch, images_1st)):
|
||||
if not args.highres_fix_latents_upscaling:
|
||||
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
|
||||
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
|
||||
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
|
||||
batch_2nd.append(bd_2nd)
|
||||
batch = batch_2nd
|
||||
|
||||
# このバッチの情報を取り出す
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image), \
|
||||
return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \
|
||||
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
|
||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||
|
||||
@@ -2353,7 +2395,7 @@ def main(args):
|
||||
all_images_are_same = True
|
||||
all_masks_are_same = True
|
||||
all_guide_images_are_same = True
|
||||
for i, ((_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
||||
for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
||||
prompts.append(prompt)
|
||||
negative_prompts.append(negative_prompt)
|
||||
seeds.append(seed)
|
||||
@@ -2413,8 +2455,10 @@ def main(args):
|
||||
n.set_multiplier(m)
|
||||
|
||||
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
||||
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
||||
if highres_1st and not args.highres_fix_save_1st:
|
||||
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises,
|
||||
vae_batch_size=args.vae_batch_size, return_latents=return_latents,
|
||||
clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
||||
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
||||
return images
|
||||
|
||||
# save image
|
||||
@@ -2612,9 +2656,9 @@ def main(args):
|
||||
print("Use previous image as guide image.")
|
||||
guide_image = prev_image
|
||||
|
||||
b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||
b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||
BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
|
||||
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
||||
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
||||
process_batch(batch_data, highres_fix)
|
||||
batch_data.clear()
|
||||
|
||||
@@ -2658,6 +2702,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
|
||||
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
||||
parser.add_argument("--vae_batch_size", type=float, default=None,
|
||||
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率")
|
||||
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
||||
parser.add_argument('--sampler', type=str, default='ddim',
|
||||
choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver',
|
||||
@@ -2713,6 +2759,8 @@ if __name__ == '__main__':
|
||||
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
||||
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
||||
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
||||
parser.add_argument("--highres_fix_latents_upscaling", action='store_true',
|
||||
help="use latents upscaling for highres fix / highres fixでlatentで拡大する")
|
||||
parser.add_argument("--negative_scale", type=float, default=None,
|
||||
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user