diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 7b4ef2e5..e025c74e 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -1981,7 +1981,6 @@ def main(args):
     imported_module = importlib.import_module(network_module)
 
     network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
-    network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
 
     net_kwargs = {}
     if args.network_args and i < len(args.network_args):
@@ -1992,22 +1991,21 @@ def main(args):
         key, value = net_arg.split("=")
         net_kwargs[key] = value
 
-    network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
-    if network is None:
-      return
-
     if args.network_weights and i < len(args.network_weights):
       network_weight = args.network_weights[i]
       print("load network weights from:", network_weight)
 
-      if os.path.splitext(network_weight)[1] == '.safetensors':
-        from safetensors.torch import safe_open
-        with safe_open(network_weight, framework="pt") as f:
-          metadata = f.metadata()
-        if metadata is not None:
-          print(f"metadata for: {network_weight}: {metadata}")
+      from safetensors.torch import safe_open
+      with safe_open(network_weight, framework="pt") as f:
+        metadata = f.metadata()
+      if metadata is not None:
+        print(f"metadata for: {network_weight}: {metadata}")
 
-      network.load_weights(network_weight)
+      network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
+    else:
+      raise ValueError("No weight. Weight is required.")
+    if network is None:
+      return
 
     network.apply_to(text_encoder, unet)
@@ -2526,8 +2524,6 @@ if __name__ == '__main__':
   parser.add_argument("--network_weights", type=str, default=None, nargs='*',
                       help='Hypernetwork weights to load / Hypernetworkの重み')
   parser.add_argument("--network_mul", type=float, default=None, nargs='*',
                       help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
-  parser.add_argument("--network_dim", type=int, default=None, nargs='*',
-                      help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
   parser.add_argument("--network_args", type=str, default=None, nargs='*',
                       help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
   parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
diff --git a/networks/lora.py b/networks/lora.py
index 3f8244e0..9243f1e1 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -13,9 +13,11 @@ class LoRAModule(torch.nn.Module):
   replaces forward method of the original Linear, instead of replacing the original Linear module.
   """
 
-  def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4):
+  def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
+    """ if alpha == 0 or None, alpha is rank (no scaling). """
     super().__init__()
     self.lora_name = lora_name
+    self.lora_dim = lora_dim
 
     if org_module.__class__.__name__ == 'Conv2d':
       in_dim = org_module.in_channels
@@ -28,6 +30,12 @@ class LoRAModule(torch.nn.Module):
       self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
       self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
 
+    if type(alpha) == torch.Tensor:
+      alpha = alpha.detach().numpy()
+    alpha = lora_dim if alpha is None or alpha == 0 else alpha
+    self.scale = alpha / self.lora_dim
+    self.register_buffer('alpha', torch.tensor(alpha))          # 定数として扱える
+
     # same as microsoft's
     torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
     torch.nn.init.zeros_(self.lora_up.weight)
@@ -41,13 +49,37 @@ class LoRAModule(torch.nn.Module):
     del self.org_module
 
   def forward(self, x):
-    return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier
+    return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
 
-def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs):
+def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
   if network_dim is None:
     network_dim = 4                     # default
-  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim)
+  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+  return network
+
+
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
+  if os.path.splitext(file)[1] == '.safetensors':
+    from safetensors.torch import load_file, safe_open
+    weights_sd = load_file(file)
+  else:
+    weights_sd = torch.load(file, map_location='cpu')
+
+  # get dim (rank)
+  network_alpha = None
+  network_dim = None
+  for key, value in weights_sd.items():
+    if network_alpha is None and 'alpha' in key:
+      network_alpha = value
+    if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
+      network_dim = value.size()[0]
+
+  if network_alpha is None:
+    network_alpha = network_dim
+
+  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+  network.weights_sd = weights_sd
   return network
 
 
@@ -57,10 +89,11 @@ class LoRANetwork(torch.nn.Module):
   LORA_PREFIX_UNET = 'lora_unet'
   LORA_PREFIX_TEXT_ENCODER = 'lora_te'
 
-  def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
+  def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
     super().__init__()
     self.multiplier = multiplier
     self.lora_dim = lora_dim
+    self.alpha = alpha
 
     # create module instances
     def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
@@ -71,7 +104,7 @@ class LoRANetwork(torch.nn.Module):
             if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
               lora_name = prefix + '.' + name + '.' + child_name
               lora_name = lora_name.replace('.', '_')
-              lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim)
+              lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
               loras.append(lora)
       return loras
@@ -149,21 +182,21 @@ class LoRANetwork(torch.nn.Module):
       return params
 
     self.requires_grad_(True)
-    params = []
+    all_params = []
 
     if self.text_encoder_loras:
       param_data = {'params': enumerate_params(self.text_encoder_loras)}
       if text_encoder_lr is not None:
         param_data['lr'] = text_encoder_lr
-      params.append(param_data)
+      all_params.append(param_data)
 
     if self.unet_loras:
       param_data = {'params': enumerate_params(self.unet_loras)}
       if unet_lr is not None:
         param_data['lr'] = unet_lr
-      params.append(param_data)
+      all_params.append(param_data)
 
-    return params
+    return all_params
 
   def prepare_grad_etc(self, text_encoder, unet):
     self.requires_grad_(True)
diff --git a/train_network.py b/train_network.py
index 70db4450..88014ddb 100644
--- a/train_network.py
+++ b/train_network.py
@@ -107,7 +107,8 @@ def train(args):
       key, value = net_arg.split('=')
       net_kwargs[key] = value
 
-  network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs)
+  # if a new network is added in future, add if ~ then blocks for each network (;'∀')
+  network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
   if network is None:
     return
 
@@ -243,7 +244,8 @@ def train(args):
         "ss_lr_warmup_steps": args.lr_warmup_steps,
         "ss_lr_scheduler": args.lr_scheduler,
         "ss_network_module": args.network_module,
-        "ss_network_dim": args.network_dim,          # None means default because another network than LoRA may have another default dim
+        "ss_network_dim": args.network_dim,          # None means default because another network than LoRA may have another default dim
+        "ss_network_alpha": args.network_alpha,      # some networks may not use this value
         "ss_mixed_precision": args.mixed_precision,
         "ss_full_fp16": bool(args.full_fp16),
         "ss_v2": bool(args.v2),
@@ -445,6 +447,8 @@ if __name__ == '__main__':
   parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
   parser.add_argument("--network_dim", type=int, default=None,
                       help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
+  parser.add_argument("--network_alpha", type=float, default=1,
+                      help='alpha for LoRA weight scaling, 0 for no scaling (same as old version) / LoRaの重み調整のalpha値、0で調整なし(旧バージョンと同じ)')
   parser.add_argument("--network_args", type=str, default=None, nargs='*',
                       help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
   parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")