fix to work LyCORIS<0.1.6

This commit is contained in:
Kohya S
2023-06-06 21:59:57 +09:00
parent 98635ebde2
commit bb91a10b5f
2 changed files with 4 additions and 4 deletions

View File

@@ -400,7 +400,7 @@ def parse_block_lr_kwargs(nw_kwargs):
return down_lr_weight, mid_lr_weight, up_lr_weight return down_lr_weight, mid_lr_weight, up_lr_weight
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs):
if network_dim is None: if network_dim is None:
network_dim = 4 # default network_dim = 4 # default
if network_alpha is None: if network_alpha is None:
@@ -455,7 +455,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
multiplier=multiplier, multiplier=multiplier,
lora_dim=network_dim, lora_dim=network_dim,
alpha=network_alpha, alpha=network_alpha,
dropout=dropout, dropout=neuron_dropout,
rank_dropout=rank_dropout, rank_dropout=rank_dropout,
module_dropout=module_dropout, module_dropout=module_dropout,
conv_lora_dim=conv_dim, conv_lora_dim=conv_dim,

View File

@@ -212,7 +212,7 @@ def train(args):
else: else:
# LyCORIS will work with this... # LyCORIS will work with this...
network = network_module.create_network( network = network_module.create_network(
1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, dropout=args.network_dropout, **net_kwargs 1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, neuron_dropout=args.network_dropout, **net_kwargs
) )
if network is None: if network is None:
return return
@@ -724,7 +724,7 @@ def train(args):
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
if args.scale_weight_norms: if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs,**logs}) progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.logging_dir is not None: if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)