diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 5d77b9e5..b5d18d9b 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -103,7 +103,8 @@ def svd(args): if args.device: mat = mat.to(args.device) - # print(mat.size(), mat.device, rank, in_dim, out_dim) + + # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim if conv2d: @@ -137,27 +138,17 @@ def svd(args): lora_weights[lora_name] = (U, Vh) # make state dict for LoRA - lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict - lora_sd = lora_network_o.state_dict() - print(f"LoRA has {len(lora_sd)} weights.") - - for key in list(lora_sd.keys()): - if "alpha" in key: - continue - - lora_name = key.split('.')[0] - i = 0 if "lora_up" in key else 1 - - weights = lora_weights[lora_name][i] - # print(key, i, weights.size(), lora_sd[key].size()) - # if len(lora_sd[key].size()) == 4: - # weights = weights.unsqueeze(2).unsqueeze(3) - - assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}" - lora_sd[key] = weights + lora_sd = {} + for lora_name, (up_weight, down_weight) in lora_weights.items(): + lora_sd[lora_name + '.lora_up.weight'] = up_weight + lora_sd[lora_name + '.lora_down.weight'] = down_weight + lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0]) # load state dict to LoRA and save it - info = lora_network_o.load_state_dict(lora_sd) + lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd) + lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict + + info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") dir_name = os.path.dirname(args.save_to) @@ -167,7 +158,7 @@ def svd(args): # minimum metadata metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} - lora_network_o.save_weights(args.save_to, save_dtype, metadata) + lora_network_save.save_weights(args.save_to, save_dtype, metadata) print(f"LoRA weights are saved to: {args.save_to}") diff --git a/networks/lora.py b/networks/lora.py index c0181c02..6d3875dc 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -21,30 +21,34 @@ class LoRAModule(torch.nn.Module): """ if alpha == 0 or None, alpha is rank (no scaling). """ super().__init__() self.lora_name = lora_name - self.lora_dim = lora_dim if org_module.__class__.__name__ == 'Conv2d': in_dim = org_module.in_channels out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features - self.lora_dim = min(self.lora_dim, in_dim, out_dim) - if self.lora_dim != lora_dim: - print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + if org_module.__class__.__name__ == 'Conv2d': kernel_size = org_module.kernel_size stride = org_module.stride padding = org_module.padding self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) else: - in_dim = org_module.in_features - out_dim = org_module.out_features - self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) - self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - alpha = lora_dim if alpha is None or alpha == 0 else alpha + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える @@ -149,12 +153,13 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un return network -def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs): - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file, safe_open - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location='cpu') +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import load_file, safe_open + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location='cpu') # get dim/alpha mapping modules_dim = {} @@ -174,7 +179,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa # support old LoRA without alpha for key in modules_dim.keys(): if key not in modules_alpha: - modules_alpha = modules_dim[key] + modules_alpha = modules_dim[key] network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha) network.weights_sd = weights_sd @@ -183,7 +188,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa class LoRANetwork(torch.nn.Module): # is it possible to apply conv_in and conv_out? - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' @@ -245,7 +251,12 @@ class LoRANetwork(torch.nn.Module): text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE) + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules) print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") self.weights_sd = None @@ -371,7 +382,7 @@ class LoRANetwork(torch.nn.Module): else: torch.save(state_dict, file) - @staticmethod + @ staticmethod def set_regions(networks, image): image = image.astype(np.float32) / 255.0 for i, network in enumerate(networks[:3]):