common block lr args processing in create

Kohya S
2023-05-11 21:47:59 +09:00
parent af08c56ce0
commit 2767a0f9f2
2 changed files with 44 additions and 48 deletions

networks/lora.py

@@ -73,7 +73,7 @@ class LoRAModule(torch.nn.Module):
 class LoRAInfModule(LoRAModule):
     def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
         super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
         self.org_module_ref = [org_module]  # keep a reference so it can be accessed later
         self.enabled = True
@@ -319,6 +319,35 @@ class LoRAInfModule(LoRAModule):
         return out
 
+
+def parse_block_lr_kwargs(nw_kwargs):
+    down_lr_weight = nw_kwargs.get("down_lr_weight", None)
+    mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
+    up_lr_weight = nw_kwargs.get("up_lr_weight", None)
+
+    # if none of these is set, treat block-wise LR as disabled and return None
+    if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
+        return None, None, None
+
+    # extract learning rate weight for each block
+    if down_lr_weight is not None:
+        # if some parameters are not set, use zero
+        if "," in down_lr_weight:
+            down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
+
+    if mid_lr_weight is not None:
+        mid_lr_weight = float(mid_lr_weight)
+
+    if up_lr_weight is not None:
+        if "," in up_lr_weight:
+            up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
+
+    down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
+        down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
+    )
+
+    return down_lr_weight, mid_lr_weight, up_lr_weight
+
+
 def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
     if network_dim is None:
         network_dim = 4  # default
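
The comprehension in the new helper maps empty comma-separated entries to 0.0. A minimal standalone sketch of that parsing idiom (values illustrative, not from the commit):

    # empty entries in the comma list fall back to 0.0
    raw = "1,0.5,,0.25"
    parsed = [(float(s) if s else 0.0) for s in raw.split(",")]
    print(parsed)  # [1.0, 0.5, 0.0, 0.25]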
@@ -337,9 +366,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
     # block dim/alpha/lr
     block_dims = kwargs.get("block_dims", None)
-    down_lr_weight = kwargs.get("down_lr_weight", None)
-    mid_lr_weight = kwargs.get("mid_lr_weight", None)
-    up_lr_weight = kwargs.get("up_lr_weight", None)
+    down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
 
     # if any of the above is specified, enable per-block dim (rank)
     if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
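
With this change, create_network receives the block-LR settings through **kwargs and delegates all parsing to parse_block_lr_kwargs. A hedged usage sketch, assuming the repository's networks/lora.py is importable; the 12-entries-per-side layout and the exact zero-threshold behavior of get_block_lr_weight are assumptions:

    # usage sketch, assuming networks/lora.py is on the path
    from networks import lora

    nw_kwargs = {
        "down_lr_weight": "0,0,0,0,0,0,1,1,1,1,1,1",  # assumed 12 down-block multipliers
        "mid_lr_weight": "1",                          # single mid-block multiplier
        "up_lr_weight": "1,1,1,1,1,1,0,0,0,0,0,0",     # assumed 12 up-block multipliers
        "block_lr_zero_threshold": "0.05",             # small weights presumably zeroed by get_block_lr_weight
    }
    down, mid, up = lora.parse_block_lr_kwargs(nw_kwargs)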
@@ -351,22 +378,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
             block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
         )
 
-        # extract learning rate weight for each block
-        if down_lr_weight is not None:
-            # if some parameters are not set, use zero
-            if "," in down_lr_weight:
-                down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
-
-        if mid_lr_weight is not None:
-            mid_lr_weight = float(mid_lr_weight)
-
-        if up_lr_weight is not None:
-            if "," in up_lr_weight:
-                up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
-
-        down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
-            down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))
-        )
-
         # remove block dim/alpha without learning rate
         block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
@@ -634,6 +645,12 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
     network = LoRANetwork(
         text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
     )
+
+    # block lr
+    down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
+    if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
+        network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
+
     return network, weights_sd
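
Because create_network_from_weights now forwards **kwargs into parse_block_lr_kwargs, block-wise LR can be configured even when dims are read from an existing weights file. A sketch of such a call (the file name is hypothetical, and vae/text_encoder/unet are assumed to be already-loaded models):

    # usage sketch: block-LR kwargs flow through when loading from weights
    network, weights_sd = lora.create_network_from_weights(
        1.0, "my_lora.safetensors", vae, text_encoder, unet,
        down_lr_weight="0,0,0,0,0,0,1,1,1,1,1,1",
        mid_lr_weight="1",
    )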
@@ -835,7 +852,7 @@ class LoRANetwork(torch.nn.Module):
         print(f"weights are merged")
 
-    # for layer-wise LR, define a multiplier on each layer's learning rate
+    # for layer-wise LR, define a multiplier on each layer's learning rate (the argument order is reversed, but don't worry about it for now)
     def set_block_lr_weight(
         self,
         up_lr_weight: List[float] = None,

train_network.py

@@ -176,32 +176,10 @@ def train(args):
             net_kwargs[key] = value
 
     # if a new network is added in future, add if ~ then blocks for each network (;'∀')
-    if args.size_from_weights:
-        network, weights = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet)
-
-        if net_kwargs is not None:
-            down_lr_weight = net_kwargs.get("down_lr_weight", None)
-            mid_lr_weight = net_kwargs.get("mid_lr_weight", None)
-            up_lr_weight = net_kwargs.get("up_lr_weight", None)
-
-            if down_lr_weight is not None:
-                # if some parameters are not set, use zero
-                if "," in down_lr_weight:
-                    down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
-
-            if mid_lr_weight is not None:
-                mid_lr_weight = float(mid_lr_weight)
-
-            if up_lr_weight is not None:
-                if "," in up_lr_weight:
-                    up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
-
-            down_lr_weight, mid_lr_weight, up_lr_weight = network_module.get_block_lr_weight(
-                down_lr_weight, mid_lr_weight, up_lr_weight, float(net_kwargs.get("block_lr_zero_threshold", 0.0))
-            )
-
-            if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
-                network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
+    if args.dim_from_weights:
+        network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
     else:
         network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
     if network is None:
         return
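
For context, net_kwargs is assembled earlier in train() from the --network_args key=value pairs, so the block-LR strings reach parse_block_lr_kwargs with no special handling here. A standalone sketch of that assembly (the split("=", 1) detail is an assumption; the repository may split differently):

    # sketch: --network_args entries become string-valued kwargs
    net_kwargs = {}
    for net_arg in ["down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1", "mid_lr_weight=1"]:
        key, value = net_arg.split("=", 1)  # assumed splitting; values stay strings
        net_kwargs[key] = value
    print(net_kwargs["mid_lr_weight"])  # prints: 1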
@@ -786,10 +764,11 @@ def setup_parser() -> argparse.ArgumentParser:
         "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
     )
     parser.add_argument(
-        "--size_from_weights", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する"
+        "--dim_from_weights",
+        action="store_true",
+        help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
     )
 
     return parser
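
To show how the renamed flag parses alongside network args, here is a self-contained argparse sketch that mirrors (but does not import) the parser above; the --network_args definition is an assumption modeled on the key=value handling in train():

    # standalone sketch mirroring the renamed flag
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--dim_from_weights", action="store_true",
                   help="automatically determine dim (rank) from network_weights")
    p.add_argument("--network_args", type=str, default=None, nargs="*")  # assumed to mirror the repo's flag

    args = p.parse_args(["--dim_from_weights", "--network_args", "down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1", "mid_lr_weight=1"])
    print(args.dim_from_weights)  # True
    print(args.network_args)      # ['down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1', 'mid_lr_weight=1']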