fix face_crop_aug not working on finetune method, prepare upscaler

This commit is contained in:
Kohya S
2023-04-22 10:41:36 +09:00
parent 220436244c
commit 884e6bff5d
3 changed files with 403 additions and 10 deletions

View File

@@ -945,7 +945,7 @@ class PipelineLike:
# encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype)
if init_image.size()[2:] == (height // 8, width // 8):
if init_image.size()[1:] == (height // 8, width // 8):
init_latents = init_image
else:
if vae_batch_size >= batch_size:
@@ -1015,7 +1015,7 @@ class PipelineLike:
if self.control_nets:
if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
else:
text_emb_last = text_embeddings
noise_pred = original_control_net.call_unet_and_control_net(
@@ -2318,6 +2318,22 @@ def main(args):
else:
networks = []
# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler)
imported_module = importlib.import_module(args.highres_fix_upscaler)
us_kwargs = {}
if args.highres_fix_upscaler_args:
for net_arg in args.highres_fix_upscaler_args.split(";"):
key, value = net_arg.split("=")
us_kwargs[key] = value
print("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)
# ControlNetの処理
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
@@ -2590,7 +2606,7 @@ def main(args):
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
np_mask = np.full(size, 255, dtype=np.uint8)
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
network.set_region(i, i == len(networks) - 1, mask)
mask_images = None
@@ -2639,6 +2655,8 @@ def main(args):
# highres_fixの処理
if highres_fix and not highres_1st:
# 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
print("process 1st stage")
batch_1st = []
for _, base, ext in batch:
@@ -2657,12 +2675,32 @@ def main(args):
ext.network_muls,
ext.num_sub_prompts,
)
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage")
if args.highres_fix_latents_upscaling:
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
if upscaler:
# upscalerを使って画像を拡大する
lowreso_imgs = None if is_1st_latent else images_1st
lowreso_latents = None if not is_1st_latent else images_1st
# 戻り値はPIL.Image.Imageかtorch.Tensorのlatents
batch_size = len(images_1st)
vae_batch_size = (
batch_size
if args.vae_batch_size is None
else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
)
vae_batch_size = int(vae_batch_size)
images_1st = upscaler.upscale(
vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
)
elif args.highres_fix_latents_upscaling:
# latentを拡大する
org_dtype = images_1st.dtype
if images_1st.dtype == torch.bfloat16:
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
@@ -2671,10 +2709,12 @@ def main(args):
) # , antialias=True)
images_1st = images_1st.to(org_dtype)
else:
# 画像をLANCZOSで拡大する
images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st]
batch_2nd = []
for i, (bd, image) in enumerate(zip(batch, images_1st)):
if not args.highres_fix_latents_upscaling:
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext)
batch_2nd.append(bd_2nd)
batch = batch_2nd
@@ -3229,6 +3269,16 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
)
parser.add_argument(
"--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
)
parser.add_argument(
"--highres_fix_upscaler_args",
type=str,
default=None,
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
)
parser.add_argument(
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
)