diff --git a/networks/lora.py b/networks/lora.py index 898ffce9..121a6281 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -73,7 +73,7 @@ class LoRAModule(torch.nn.Module): class LoRAInfModule(LoRAModule): def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) - + self.org_module_ref = [org_module] # 後から参照できるように self.enabled = True @@ -319,6 +319,35 @@ class LoRAInfModule(LoRAModule): return out +def parse_block_lr_kwargs(nw_kwargs): + down_lr_weight = nw_kwargs.get("down_lr_weight", None) + mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) + up_lr_weight = nw_kwargs.get("up_lr_weight", None) + + # 以上のいずれにも設定がない場合は無効としてNoneを返す + if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: + return None, None, None + + # extract learning rate weight for each block + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + return down_lr_weight, mid_lr_weight, up_lr_weight + + def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): if network_dim is None: network_dim = 4 # default @@ -337,9 +366,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un # block dim/alpha/lr block_dims = kwargs.get("block_dims", None) - down_lr_weight = kwargs.get("down_lr_weight", None) - mid_lr_weight = kwargs.get("mid_lr_weight", None) - up_lr_weight = kwargs.get("up_lr_weight", None) + down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: @@ -351,22 +378,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha ) - # extract learning rate weight for each block - if down_lr_weight is not None: - # if some parameters are not set, use zero - if "," in down_lr_weight: - down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] - - if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) - - if up_lr_weight is not None: - if "," in up_lr_weight: - up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - - down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)) - ) # remove block dim/alpha without learning rate block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( @@ -634,6 +645,12 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh network = LoRANetwork( text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class ) + + # block lr + down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + return network, weights_sd @@ -835,7 +852,7 @@ class LoRANetwork(torch.nn.Module): print(f"weights are merged") - # 層別学習率用に層ごとの学習率に対する倍率を定義する + # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( self, up_lr_weight: List[float] = None, diff --git a/train_network.py b/train_network.py index bcfd657f..b5cdfea1 100644 --- a/train_network.py +++ b/train_network.py @@ -176,32 +176,10 @@ def train(args): net_kwargs[key] = value # if a new network is added in future, add if ~ then blocks for each network (;'∀') - if args.size_from_weights: - network, weights = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet) - if net_kwargs is not None: - down_lr_weight = net_kwargs.get("down_lr_weight", None) - mid_lr_weight = net_kwargs.get("mid_lr_weight", None) - up_lr_weight = net_kwargs.get("up_lr_weight", None) - if down_lr_weight is not None: - # if some parameters are not set, use zero - if "," in down_lr_weight: - down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] - - if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) - - if up_lr_weight is not None: - if "," in up_lr_weight: - up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - - down_lr_weight, mid_lr_weight, up_lr_weight = network_module.get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, float(net_kwargs.get("block_lr_zero_threshold", 0.0)) - ) - - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + if args.dim_from_weights: + network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) else: - network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return @@ -786,10 +764,11 @@ def setup_parser() -> argparse.ArgumentParser: "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" ) parser.add_argument( - "--size_from_weights", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + "--dim_from_weights", + action="store_true", + help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", ) - return parser