mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add Deep Shrink
This commit is contained in:
@@ -2501,6 +2501,10 @@ def main(args):
|
|||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
if args.ds_depth_1 is not None:
|
||||||
|
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
|
||||||
|
|
||||||
# Extended Textual Inversion および Textual Inversionを処理する
|
# Extended Textual Inversion および Textual Inversionを処理する
|
||||||
if args.XTI_embeddings:
|
if args.XTI_embeddings:
|
||||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||||
@@ -3085,6 +3089,13 @@ def main(args):
|
|||||||
clip_prompt = None
|
clip_prompt = None
|
||||||
network_muls = None
|
network_muls = None
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
ds_depth_1 = None # means no override
|
||||||
|
ds_timesteps_1 = args.ds_timesteps_1
|
||||||
|
ds_depth_2 = args.ds_depth_2
|
||||||
|
ds_timesteps_2 = args.ds_timesteps_2
|
||||||
|
ds_ratio = args.ds_ratio
|
||||||
|
|
||||||
prompt_args = raw_prompt.strip().split(" --")
|
prompt_args = raw_prompt.strip().split(" --")
|
||||||
prompt = prompt_args[0]
|
prompt = prompt_args[0]
|
||||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||||
@@ -3156,10 +3167,51 @@ def main(args):
|
|||||||
print(f"network mul: {network_muls}")
|
print(f"network mul: {network_muls}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink depth 1
|
||||||
|
ds_depth_1 = int(m.group(1))
|
||||||
|
print(f"deep shrink depth 1: {ds_depth_1}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink timesteps 1
|
||||||
|
ds_timesteps_1 = int(m.group(1))
|
||||||
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
|
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink depth 2
|
||||||
|
ds_depth_2 = int(m.group(1))
|
||||||
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
|
print(f"deep shrink depth 2: {ds_depth_2}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink timesteps 2
|
||||||
|
ds_timesteps_2 = int(m.group(1))
|
||||||
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
|
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink ratio
|
||||||
|
ds_ratio = float(m.group(1))
|
||||||
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
|
print(f"deep shrink ratio: {ds_ratio}")
|
||||||
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
print(ex)
|
print(ex)
|
||||||
|
|
||||||
|
# override Deep Shrink
|
||||||
|
if ds_depth_1 is not None:
|
||||||
|
if ds_depth_1 < 0:
|
||||||
|
ds_depth_1 = args.ds_depth_1 or 3
|
||||||
|
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
||||||
|
|
||||||
# prepare seed
|
# prepare seed
|
||||||
if seeds is not None: # given in prompt
|
if seeds is not None: # given in prompt
|
||||||
# 数が足りないなら前のをそのまま使う
|
# 数が足りないなら前のをそのまま使う
|
||||||
@@ -3509,6 +3561,30 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_depth_1",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_timesteps_1",
|
||||||
|
type=int,
|
||||||
|
default=650,
|
||||||
|
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
|
||||||
|
)
|
||||||
|
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_timesteps_2",
|
||||||
|
type=int,
|
||||||
|
default=650,
|
||||||
|
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -361,6 +361,23 @@ def get_timestep_embedding(
|
|||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
# Deep Shrink: We do not common this function, because minimize dependencies.
|
||||||
|
def resize_like(x, target, mode="bicubic", align_corners=False):
|
||||||
|
org_dtype = x.dtype
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
|
||||||
|
if x.shape[-2:] != target.shape[-2:]:
|
||||||
|
if mode == "nearest":
|
||||||
|
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
|
||||||
|
else:
|
||||||
|
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
|
||||||
|
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
x = x.to(org_dtype)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SampleOutput:
|
class SampleOutput:
|
||||||
def __init__(self, sample):
|
def __init__(self, sample):
|
||||||
self.sample = sample
|
self.sample = sample
|
||||||
@@ -1130,6 +1147,11 @@ class UpBlock2D(nn.Module):
|
|||||||
# pop res hidden states
|
# pop res hidden states
|
||||||
res_hidden_states = res_hidden_states_tuple[-1]
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
||||||
|
hidden_states = resize_like(hidden_states, res_hidden_states)
|
||||||
|
|
||||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
@@ -1221,6 +1243,11 @@ class CrossAttnUpBlock2D(nn.Module):
|
|||||||
# pop res hidden states
|
# pop res hidden states
|
||||||
res_hidden_states = res_hidden_states_tuple[-1]
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
||||||
|
hidden_states = resize_like(hidden_states, res_hidden_states)
|
||||||
|
|
||||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
@@ -1417,6 +1444,31 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
self.conv_act = nn.SiLU()
|
self.conv_act = nn.SiLU()
|
||||||
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
|
||||||
|
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||||
|
if ds_depth_1 is None:
|
||||||
|
print("Deep Shrink is disabled.")
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||||
|
)
|
||||||
|
self.ds_depth_1 = ds_depth_1
|
||||||
|
self.ds_timesteps_1 = ds_timesteps_1
|
||||||
|
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
||||||
|
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
||||||
|
self.ds_ratio = ds_ratio
|
||||||
|
|
||||||
# region diffusers compatibility
|
# region diffusers compatibility
|
||||||
def prepare_config(self):
|
def prepare_config(self):
|
||||||
self.config = SimpleNamespace()
|
self.config = SimpleNamespace()
|
||||||
@@ -1519,9 +1571,21 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
# 2. pre-process
|
# 2. pre-process
|
||||||
sample = self.conv_in(sample)
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
# 3. down
|
|
||||||
down_block_res_samples = (sample,)
|
down_block_res_samples = (sample,)
|
||||||
for downsample_block in self.down_blocks:
|
for depth, downsample_block in enumerate(self.down_blocks):
|
||||||
|
# Deep Shrink
|
||||||
|
if self.ds_depth_1 is not None:
|
||||||
|
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||||
|
self.ds_depth_2 is not None
|
||||||
|
and depth == self.ds_depth_2
|
||||||
|
and timesteps[0] < self.ds_timesteps_1
|
||||||
|
and timesteps[0] >= self.ds_timesteps_2
|
||||||
|
):
|
||||||
|
org_dtype = sample.dtype
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
sample = sample.to(torch.float32)
|
||||||
|
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
||||||
|
|
||||||
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||||
# まあこちらのほうがわかりやすいかもしれない
|
# まあこちらのほうがわかりやすいかもしれない
|
||||||
if downsample_block.has_cross_attention:
|
if downsample_block.has_cross_attention:
|
||||||
|
|||||||
@@ -266,6 +266,23 @@ def get_timestep_embedding(
|
|||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
# Deep Shrink: We do not common this function, because minimize dependencies.
|
||||||
|
def resize_like(x, target, mode="bicubic", align_corners=False):
|
||||||
|
org_dtype = x.dtype
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
|
||||||
|
if x.shape[-2:] != target.shape[-2:]:
|
||||||
|
if mode == "nearest":
|
||||||
|
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
|
||||||
|
else:
|
||||||
|
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
|
||||||
|
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
x = x.to(org_dtype)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class GroupNorm32(nn.GroupNorm):
|
class GroupNorm32(nn.GroupNorm):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.weight.dtype != torch.float32:
|
if self.weight.dtype != torch.float32:
|
||||||
@@ -996,6 +1013,31 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
|
||||||
|
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||||
|
if ds_depth_1 is None:
|
||||||
|
print("Deep Shrink is disabled.")
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||||
|
)
|
||||||
|
self.ds_depth_1 = ds_depth_1
|
||||||
|
self.ds_timesteps_1 = ds_timesteps_1
|
||||||
|
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
||||||
|
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
||||||
|
self.ds_ratio = ds_ratio
|
||||||
|
|
||||||
# region diffusers compatibility
|
# region diffusers compatibility
|
||||||
def prepare_config(self):
|
def prepare_config(self):
|
||||||
self.config = SimpleNamespace()
|
self.config = SimpleNamespace()
|
||||||
@@ -1077,16 +1119,42 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
|
|
||||||
# h = x.type(self.dtype)
|
# h = x.type(self.dtype)
|
||||||
h = x
|
h = x
|
||||||
for module in self.input_blocks:
|
|
||||||
|
for depth, module in enumerate(self.input_blocks):
|
||||||
|
# Deep Shrink
|
||||||
|
if self.ds_depth_1 is not None:
|
||||||
|
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||||
|
self.ds_depth_2 is not None
|
||||||
|
and depth == self.ds_depth_2
|
||||||
|
and timesteps[0] < self.ds_timesteps_1
|
||||||
|
and timesteps[0] >= self.ds_timesteps_2
|
||||||
|
):
|
||||||
|
# print("downsample", h.shape, self.ds_ratio)
|
||||||
|
org_dtype = h.dtype
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
h = h.to(torch.float32)
|
||||||
|
h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
||||||
|
|
||||||
h = call_module(module, h, emb, context)
|
h = call_module(module, h, emb, context)
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
|
|
||||||
h = call_module(self.middle_block, h, emb, context)
|
h = call_module(self.middle_block, h, emb, context)
|
||||||
|
|
||||||
for module in self.output_blocks:
|
for module in self.output_blocks:
|
||||||
|
# Deep Shrink
|
||||||
|
if self.ds_depth_1 is not None:
|
||||||
|
if hs[-1].shape[-2:] != h.shape[-2:]:
|
||||||
|
# print("upsample", h.shape, hs[-1].shape)
|
||||||
|
h = resize_like(h, hs[-1])
|
||||||
|
|
||||||
h = torch.cat([h, hs.pop()], dim=1)
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
h = call_module(module, h, emb, context)
|
h = call_module(module, h, emb, context)
|
||||||
|
|
||||||
|
# Deep Shrink: in case of depth 0
|
||||||
|
if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
|
||||||
|
# print("upsample", h.shape, x.shape)
|
||||||
|
h = resize_like(h, x)
|
||||||
|
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
h = call_module(self.out, h, emb, context)
|
h = call_module(self.out, h, emb, context)
|
||||||
|
|
||||||
|
|||||||
@@ -1696,6 +1696,10 @@ def main(args):
|
|||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
if args.ds_depth_1 is not None:
|
||||||
|
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
|
||||||
|
|
||||||
# Textual Inversionを処理する
|
# Textual Inversionを処理する
|
||||||
if args.textual_inversion_embeddings:
|
if args.textual_inversion_embeddings:
|
||||||
token_ids_embeds1 = []
|
token_ids_embeds1 = []
|
||||||
@@ -2286,6 +2290,13 @@ def main(args):
|
|||||||
clip_prompt = None
|
clip_prompt = None
|
||||||
network_muls = None
|
network_muls = None
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
ds_depth_1 = None # means no override
|
||||||
|
ds_timesteps_1 = args.ds_timesteps_1
|
||||||
|
ds_depth_2 = args.ds_depth_2
|
||||||
|
ds_timesteps_2 = args.ds_timesteps_2
|
||||||
|
ds_ratio = args.ds_ratio
|
||||||
|
|
||||||
prompt_args = raw_prompt.strip().split(" --")
|
prompt_args = raw_prompt.strip().split(" --")
|
||||||
prompt = prompt_args[0]
|
prompt = prompt_args[0]
|
||||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||||
@@ -2393,10 +2404,51 @@ def main(args):
|
|||||||
print(f"network mul: {network_muls}")
|
print(f"network mul: {network_muls}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink depth 1
|
||||||
|
ds_depth_1 = int(m.group(1))
|
||||||
|
print(f"deep shrink depth 1: {ds_depth_1}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink timesteps 1
|
||||||
|
ds_timesteps_1 = int(m.group(1))
|
||||||
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
|
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink depth 2
|
||||||
|
ds_depth_2 = int(m.group(1))
|
||||||
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
|
print(f"deep shrink depth 2: {ds_depth_2}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink timesteps 2
|
||||||
|
ds_timesteps_2 = int(m.group(1))
|
||||||
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
|
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # deep shrink ratio
|
||||||
|
ds_ratio = float(m.group(1))
|
||||||
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
|
print(f"deep shrink ratio: {ds_ratio}")
|
||||||
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
print(ex)
|
print(ex)
|
||||||
|
|
||||||
|
# override Deep Shrink
|
||||||
|
if ds_depth_1 is not None:
|
||||||
|
if ds_depth_1 < 0:
|
||||||
|
ds_depth_1 = args.ds_depth_1 or 3
|
||||||
|
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
||||||
|
|
||||||
# prepare seed
|
# prepare seed
|
||||||
if seeds is not None: # given in prompt
|
if seeds is not None: # given in prompt
|
||||||
# 数が足りないなら前のをそのまま使う
|
# 数が足りないなら前のをそのまま使う
|
||||||
@@ -2734,6 +2786,31 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
default=None,
|
default=None,
|
||||||
help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する",
|
help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_depth_1",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_timesteps_1",
|
||||||
|
type=int,
|
||||||
|
default=650,
|
||||||
|
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
|
||||||
|
)
|
||||||
|
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_timesteps_2",
|
||||||
|
type=int,
|
||||||
|
default=650,
|
||||||
|
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
|
||||||
|
)
|
||||||
|
|
||||||
# # parser.add_argument(
|
# # parser.add_argument(
|
||||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||||
# )
|
# )
|
||||||
|
|||||||
Reference in New Issue
Block a user