mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add Deep Shrink
This commit is contained in:
@@ -2501,6 +2501,10 @@ def main(args):
|
||||
if args.diffusers_xformers:
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# Deep Shrink
|
||||
if args.ds_depth_1 is not None:
|
||||
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
|
||||
|
||||
# Extended Textual Inversion および Textual Inversionを処理する
|
||||
if args.XTI_embeddings:
|
||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||
@@ -3085,6 +3089,13 @@ def main(args):
|
||||
clip_prompt = None
|
||||
network_muls = None
|
||||
|
||||
# Deep Shrink
|
||||
ds_depth_1 = None # means no override
|
||||
ds_timesteps_1 = args.ds_timesteps_1
|
||||
ds_depth_2 = args.ds_depth_2
|
||||
ds_timesteps_2 = args.ds_timesteps_2
|
||||
ds_ratio = args.ds_ratio
|
||||
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
@@ -3156,10 +3167,51 @@ def main(args):
|
||||
print(f"network mul: {network_muls}")
|
||||
continue
|
||||
|
||||
# Deep Shrink
|
||||
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink depth 1
|
||||
ds_depth_1 = int(m.group(1))
|
||||
print(f"deep shrink depth 1: {ds_depth_1}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink timesteps 1
|
||||
ds_timesteps_1 = int(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink depth 2
|
||||
ds_depth_2 = int(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink depth 2: {ds_depth_2}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink timesteps 2
|
||||
ds_timesteps_2 = int(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink ratio
|
||||
ds_ratio = float(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink ratio: {ds_ratio}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
# override Deep Shrink
|
||||
if ds_depth_1 is not None:
|
||||
if ds_depth_1 < 0:
|
||||
ds_depth_1 = args.ds_depth_1 or 3
|
||||
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
||||
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
@@ -3509,6 +3561,30 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
# )
|
||||
|
||||
# Deep Shrink
|
||||
parser.add_argument(
|
||||
"--ds_depth_1",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds_timesteps_1",
|
||||
type=int,
|
||||
default=650,
|
||||
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
|
||||
)
|
||||
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
|
||||
parser.add_argument(
|
||||
"--ds_timesteps_2",
|
||||
type=int,
|
||||
default=650,
|
||||
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user