diff --git a/networks/lora.py b/networks/lora.py
index 2318605a..9a5d95b9 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -26,8 +26,16 @@ class LoRAModule(torch.nn.Module):
     if org_module.__class__.__name__ == 'Conv2d':
       in_dim = org_module.in_channels
       out_dim = org_module.out_channels
-      self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
-      self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
+
+      self.lora_dim = min(self.lora_dim, in_dim, out_dim)
+      if self.lora_dim != lora_dim:
+        print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+
+      kernel_size = org_module.kernel_size
+      stride = org_module.stride
+      padding = org_module.padding
+      self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+      self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
     else:
       in_dim = org_module.in_features
       out_dim = org_module.out_features
@@ -56,6 +64,7 @@ class LoRAModule(torch.nn.Module):
 
   def set_region(self, region):
     self.region = region
+    self.region_mask = None
 
   def forward(self, x):
     if self.region is None:
@@ -67,6 +76,7 @@ class LoRAModule(torch.nn.Module):
       self.region = None
       return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
 
+    # calculate region mask first time
    if self.region_mask is None:
       if len(x.size()) == 4:
         h, w = x.size()[2:4]
@@ -95,7 +105,43 @@ class LoRAModule(torch.nn.Module):
 def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
   if network_dim is None:
     network_dim = 4  # default
-  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+
+  # extract dim/alpha for conv2d, and block dim
+  conv_dim = int(kwargs.get('conv_dim', network_dim))
+  conv_alpha = kwargs.get('conv_alpha', network_alpha)
+  if conv_alpha is not None:
+    conv_alpha = float(conv_alpha)
+
+  """
+  block_dims = kwargs.get("block_dims")
+  block_alphas = None
+
+  if block_dims is not None:
+    block_dims = [int(d) for d in block_dims.split(',')]
+    assert len(block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
+    block_alphas = kwargs.get("block_alphas")
+    if block_alphas is None:
+      block_alphas = [1] * len(block_dims)
+    else:
+      block_alphas = [int(a) for a in block_alphas.split(',')]
+      assert len(block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
+
+  conv_block_dims = kwargs.get("conv_block_dims")
+  conv_block_alphas = None
+
+  if conv_block_dims is not None:
+    conv_block_dims = [int(d) for d in conv_block_dims.split(',')]
+    assert len(conv_block_dims) == NUM_BLOCKS, f"Number of block dimensions is not same to {NUM_BLOCKS}"
+    conv_block_alphas = kwargs.get("conv_block_alphas")
+    if conv_block_alphas is None:
+      conv_block_alphas = [1] * len(conv_block_dims)
+    else:
+      conv_block_alphas = [int(a) for a in conv_block_alphas.split(',')]
+      assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
+  """
+
+  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim,
+                        alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha)
   return network
 
 
@@ -106,45 +152,88 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
   else:
     weights_sd = torch.load(file, map_location='cpu')
 
-  # get dim (rank)
-  network_alpha = None
-  network_dim = None
+  # get dim/alpha mapping
+  modules_dim = {}
+  modules_alpha = {}
   for key, value in weights_sd.items():
-    if network_alpha is None and 'alpha' in key:
-      network_alpha = value
-    if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
-      network_dim = value.size()[0]
+    if '.' not in key:
+      continue
 
-  if network_alpha is None:
-    network_alpha = network_dim
+    lora_name = key.split('.')[0]
+    if 'alpha' in key:
+      modules_alpha[lora_name] = value
+    elif 'lora_down' in key:
+      dim = value.size()[0]
+      modules_dim[lora_name] = dim
+      print(lora_name, value.size(), dim)
 
-  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
+  # support old LoRA without alpha
+  for key in modules_dim.keys():
+    if key not in modules_alpha:
+      modules_alpha[key] = modules_dim[key]
+
+  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
   network.weights_sd = weights_sd
   return network
 
 
 class LoRANetwork(torch.nn.Module):
-  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+  # is it possible to apply conv_in and conv_out?
+  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"]
   TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
   LORA_PREFIX_UNET = 'lora_unet'
   LORA_PREFIX_TEXT_ENCODER = 'lora_te'
 
-  def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
+  def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1, conv_lora_dim=None, conv_alpha=None, modules_dim=None, modules_alpha=None) -> None:
     super().__init__()
     self.multiplier = multiplier
+
     self.lora_dim = lora_dim
     self.alpha = alpha
+    self.conv_lora_dim = conv_lora_dim
+    self.conv_alpha = conv_alpha
+
+    if modules_dim is not None:
+      print(f"create LoRA network from weights")
+    else:
+      print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+
+    self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
+    if self.apply_to_conv2d_3x3:
+      if self.conv_alpha is None:
+        self.conv_alpha = self.alpha
+      print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
     # create module instances
     def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
       loras = []
       for name, module in root_module.named_modules():
         if module.__class__.__name__ in target_replace_modules:
+          # TODO get block index here
           for child_name, child_module in module.named_modules():
-            if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
+            is_linear = child_module.__class__.__name__ == "Linear"
+            is_conv2d = child_module.__class__.__name__ == "Conv2d"
+            is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+            if is_linear or is_conv2d:
               lora_name = prefix + '.' + name + '.' + child_name
               lora_name = lora_name.replace('.', '_')
-              lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
+
+              if modules_dim is not None:
+                if lora_name not in modules_dim:
+                  continue  # no LoRA module in this weights file
+                dim = modules_dim[lora_name]
+                alpha = modules_alpha[lora_name]
+              else:
+                if is_linear or is_conv2d_1x1:
+                  dim = self.lora_dim
+                  alpha = self.alpha
+                elif self.apply_to_conv2d_3x3:
+                  dim = self.conv_lora_dim
+                  alpha = self.conv_alpha
+                else:
+                  continue
+
+              lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
               loras.append(lora)
       return loras
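Usage sketch (illustrative, not part of the patch): the new conv_dim / conv_alpha values travel as plain keyword arguments into create_network, which parses them with int() / float() and forwards them to LoRANetwork as conv_lora_dim / conv_alpha; in the sd-scripts trainers they would typically be supplied via --network_args. The vae / text_encoder / unet variables below stand for models the caller has already loaded and are assumptions of this sketch.

    import networks.lora as lora

    # Linear and Conv2d 1x1 layers keep rank 8 / alpha 4 as before;
    # Conv2d 3x3 layers inside the newly targeted modules (ResnetBlock2D,
    # Downsample2D, Upsample2D) get rank 4 / alpha 1.
    network = lora.create_network(
        1.0,              # multiplier
        8,                # network_dim
        4,                # network_alpha
        vae, text_encoder, unet,
        conv_dim="4",     # strings are fine: create_network applies int()
        conv_alpha="1",   # and float() itself
    )

When loading existing weights instead, create_network_from_weights now rebuilds a per-module rank/alpha map keyed by each lora_* module name from the state dict, so older files without Conv2d 3x3 entries and newer files with them are both handled.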