fix crash gen script, change to network_dropout

This commit is contained in:
Kohya S
2023-06-01 20:07:04 +09:00
parent f4c9276336
commit f8e8df5a04
2 changed files with 21 additions and 19 deletions

View File

@@ -69,14 +69,17 @@ class LoRAModule(torch.nn.Module):
def forward(self, x):
if self.dropout:
return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x),p=self.dropout)) * self.multiplier * self.scale
return (
self.org_forward(x)
+ self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
)
else:
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
class LoRAInfModule(LoRAModule):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, dropout)
self.org_module_ref = [org_module] # 後から参照できるように
self.enabled = True
@@ -382,7 +385,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
)
# remove block dim/alpha without learning rate
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
@@ -400,6 +402,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
block_dims=block_dims,
@@ -407,7 +410,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_block_dims=conv_block_dims,
conv_block_alphas=conv_block_alphas,
varbose=True,
dropout=dropout,
)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
@@ -676,6 +678,7 @@ class LoRANetwork(torch.nn.Module):
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
conv_lora_dim=None,
conv_alpha=None,
block_dims=None,
@@ -686,7 +689,6 @@ class LoRANetwork(torch.nn.Module):
modules_alpha=None,
module_class=LoRAModule,
varbose=False,
dropout=None
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -704,19 +706,18 @@ class LoRANetwork(torch.nn.Module):
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
print(f"Neuron dropout: p={self.dropout}")
if modules_dim is not None:
print(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"create LoRA network from block_dims, neuron dropout: p={self.dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, neuron dropout: p={self.dropout}")
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")