add Deep Shrink

This commit is contained in:
Kohya S
2023-11-23 19:40:48 +09:00
parent d0923d6710
commit 6d6d86260b
4 changed files with 288 additions and 3 deletions

View File

@@ -2501,6 +2501,10 @@ def main(args):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
# Deep Shrink
if args.ds_depth_1 is not None:
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
# Extended Textual Inversion および Textual Inversionを処理する
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
@@ -3085,6 +3089,13 @@ def main(args):
clip_prompt = None
network_muls = None
# Deep Shrink
ds_depth_1 = None # means no override
ds_timesteps_1 = args.ds_timesteps_1
ds_depth_2 = args.ds_depth_2
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
@@ -3156,10 +3167,51 @@ def main(args):
print(f"network mul: {network_muls}")
continue
# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}")
continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}")
continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}")
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
# override Deep Shrink
if ds_depth_1 is not None:
if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
@@ -3509,6 +3561,30 @@ def setup_parser() -> argparse.ArgumentParser:
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )
# Deep Shrink
parser.add_argument(
"--ds_depth_1",
type=int,
default=None,
help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする",
)
parser.add_argument(
"--ds_timesteps_1",
type=int,
default=650,
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
)
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
parser.add_argument(
"--ds_timesteps_2",
type=int,
default=650,
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
)
parser.add_argument(
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
)
return parser