fix crash gen script, change to network_dropout

Kohya S
2023-06-01 20:07:04 +09:00
parent f4c9276336
commit f8e8df5a04
2 changed files with 21 additions and 19 deletions

View File

@@ -69,14 +69,17 @@ class LoRAModule(torch.nn.Module):
     def forward(self, x):
         if self.dropout:
-            return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x),p=self.dropout)) * self.multiplier * self.scale
+            return (
+                self.org_forward(x)
+                + self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
+            )
         else:
             return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale


 class LoRAInfModule(LoRAModule):
-    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
-        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
+    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
+        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, dropout)
         self.org_module_ref = [org_module]  # 後から参照できるように
         self.enabled = True
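
For orientation, here is a minimal standalone sketch of the forward pass this hunk reformats: the LoRA delta is computed on the dropped-out low-rank activations and scaled before being added to the original module's output. It is not the repository's class; names other than lora_down/lora_up/multiplier/scale/dropout are illustrative, and the frozen original layer is modeled as a plain Linear.

import torch
import torch.nn.functional as F

class LoRALinearSketch(torch.nn.Module):
    # Minimal stand-in for LoRAModule applied to a Linear layer.
    def __init__(self, in_dim, out_dim, lora_dim=4, alpha=1, multiplier=1.0, dropout=None):
        super().__init__()
        self.org_forward = torch.nn.Linear(in_dim, out_dim)  # stands in for the frozen original module
        self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
        self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
        self.multiplier = multiplier
        self.scale = alpha / lora_dim
        self.dropout = dropout

    def forward(self, x):
        if self.dropout:
            # dropout acts on the rank-r activations, then the result is projected back up
            return (
                self.org_forward(x)
                + self.lora_up(F.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
            )
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale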
@@ -382,7 +385,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
     )

     # remove block dim/alpha without learning rate
     block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
         block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
@@ -400,6 +402,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         multiplier=multiplier,
         lora_dim=network_dim,
         alpha=network_alpha,
+        dropout=dropout,
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
         block_dims=block_dims,
@@ -407,7 +410,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         conv_block_dims=conv_block_dims,
         conv_block_alphas=conv_block_alphas,
         varbose=True,
-        dropout=dropout,
     )

     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
@@ -676,6 +678,7 @@ class LoRANetwork(torch.nn.Module):
         multiplier=1.0,
         lora_dim=4,
         alpha=1,
+        dropout=None,
         conv_lora_dim=None,
         conv_alpha=None,
         block_dims=None,
@@ -686,7 +689,6 @@ class LoRANetwork(torch.nn.Module):
         modules_alpha=None,
         module_class=LoRAModule,
         varbose=False,
-        dropout=None
     ) -> None:
         """
         LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -704,19 +706,18 @@ class LoRANetwork(torch.nn.Module):
         self.conv_lora_dim = conv_lora_dim
         self.conv_alpha = conv_alpha
         self.dropout = dropout
-        print(f"Neuron dropout: p={self.dropout}")

         if modules_dim is not None:
             print(f"create LoRA network from weights")
         elif block_dims is not None:
-            print(f"create LoRA network from block_dims")
+            print(f"create LoRA network from block_dims, neuron dropout: p={self.dropout}")
             print(f"block_dims: {block_dims}")
             print(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
                 print(f"conv_block_dims: {conv_block_dims}")
                 print(f"conv_block_alphas: {conv_block_alphas}")
         else:
-            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, neuron dropout: p={self.dropout}")
             if self.conv_lora_dim is not None:
                 print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

View File

@@ -209,8 +209,9 @@ def train(args):
     if args.dim_from_weights:
         network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
     else:
+        # LyCORIS will work with this...
         network = network_module.create_network(
-            1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs
+            1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, dropout=args.network_dropout, **net_kwargs
         )
     if network is None:
         return
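
The in-code comment above hints at why dropout is now passed as a keyword instead of a positional argument: a third-party network module (LyCORIS is the one named) whose create_network takes only the six documented positional parameters plus **kwargs would crash on an extra positional value but absorbs an extra keyword harmlessly. A rough sketch of that behavior, using an assumed signature rather than LyCORIS's actual one:

# Assumed third-party signature: extra options arrive only via **kwargs.
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
    return kwargs.get("dropout")  # unknown keys are simply ignored or picked out

create_network(1.0, 16, 8, None, None, None, dropout=0.1)  # OK: absorbed by **kwargs
# create_network(1.0, 16, 8, None, None, None, 0.1)        # TypeError: too many positional arguments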
@@ -367,7 +368,8 @@ def train(args):
"ss_lr_scheduler": args.lr_scheduler, "ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module, "ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_alpha": args.network_alpha, # some networks may not use this value "ss_network_alpha": args.network_alpha, # some networks may not have alpha
"ss_network_dropout": args.network_dropout, # some networks may not have dropout
"ss_mixed_precision": args.mixed_precision, "ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16), "ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2), "ss_v2": bool(args.v2),
@@ -391,7 +393,6 @@ def train(args):
"ss_prior_loss_weight": args.prior_loss_weight, "ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma, "ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms, "ss_scale_weight_norms": args.scale_weight_norms,
"ss_dropout": args.dropout,
} }
if use_user_config: if use_user_config:
@@ -798,6 +799,12 @@ def setup_parser() -> argparse.ArgumentParser:
         default=1,
         help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定",
     )
+    parser.add_argument(
+        "--network_dropout",
+        type=float,
+        default=None,
+        help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする0またはNoneはdropoutなし、1は全ニューロンをdropout",
+    )
     parser.add_argument(
         "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
     )
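
A quick check of the renamed option, useful when migrating existing commands: what the hunk below removes as --dropout is now --network_dropout, parsed as a float with None as the default. The parse here is illustrative and the value 0.1 is only an example.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--network_dropout", type=float, default=None)

args = parser.parse_args(["--network_dropout", "0.1"])
assert args.network_dropout == 0.1  # 0/None leaves dropout off; 1.0 would drop every neuron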
@@ -819,12 +826,6 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ1が初期値としては適当",
     )
-    parser.add_argument(
-        "--dropout",
-        type=float,
-        default=None,
-        help="Drops neurons out of training every step (0 is default behavior, 1 would drop all neurons)",
-    )
     parser.add_argument(
         "--base_weights",
         type=str,