diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index de607c52..2f1b9a37 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -14,9 +14,11 @@ from tqdm import tqdm from PIL import Image import accelerate from transformers import CLIPTextModel +from safetensors.torch import load_file from library import device_utils from library.device_utils import init_ipex, get_preferred_device +from networks import oft_flux init_ipex() @@ -405,7 +407,7 @@ if __name__ == "__main__": type=str, nargs="*", default=[], - help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", + help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)", ) parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) @@ -482,9 +484,19 @@ if __name__ == "__main__": else: multiplier = 1.0 - lora_model, weights_sd = lora_flux.create_network_from_weights( - multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True - ) + weights_sd = load_file(weights_file) + is_lora = is_oft = False + for key in weights_sd.keys(): + if key.startswith("lora"): + is_lora = True + if key.startswith("oft"): + is_oft = True + if is_lora or is_oft: + break + + module = lora_flux if is_lora else oft_flux + lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True) + if args.merge_lora_weights: lora_model.merge_to([clip_l, t5xxl], model, weights_sd) else: diff --git a/networks/lora_flux.py b/networks/lora_flux.py index dd267de0..ea7df8b4 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -41,7 +41,11 @@ class LoRAModule(torch.nn.Module): module_dropout=None, split_dims: Optional[List[int]] = None, ): - """if alpha == 0 or None, alpha is rank (no scaling).""" + """ + if alpha == 0 or None, alpha is rank (no scaling). + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ super().__init__() self.lora_name = lora_name diff --git a/networks/oft.py b/networks/oft.py index 6321def3..0c3a5393 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -51,7 +51,7 @@ class OFTModule(torch.nn.Module): alpha = alpha.detach().numpy() # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility - # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha self.constraint = alpha * out_dim self.register_buffer("alpha", torch.tensor(alpha)) diff --git a/networks/oft_flux.py b/networks/oft_flux.py new file mode 100644 index 00000000..27b8b637 --- /dev/null +++ b/networks/oft_flux.py @@ -0,0 +1,482 @@ +# OFT network module + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +import einops +from transformers import CLIPTextModel +import numpy as np +import torch +import torch.nn.functional as F +import re +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class OFTModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + ): + """ + dim -> num blocks + alpha -> constraint + + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.oft_name = oft_name + self.num_blocks = dim + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + self.register_buffer("alpha", torch.tensor(alpha)) + + # No conv2d in FLUX + # if "Linear" in org_module.__class__.__name__: + self.out_dim = org_module.out_features + # elif "Conv" in org_module.__class__.__name__: + # out_dim = org_module.out_channels + + if split_dims is None: + split_dims = [self.out_dim] + else: + assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim" + self.split_dims = split_dims + + # assert all dim is divisible by num_blocks + for split_dim in self.split_dims: + assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks" + + self.constraint = [alpha * split_dim for split_dim in self.split_dims] + self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims] + self.oft_blocks = torch.nn.ParameterList( + [torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size] + ) + self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size] + + self.shape = org_module.weight.shape + self.multiplier = multiplier + self.org_module = [org_module] # moduleにならないようにlistに入れる + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + if self.I[0].device != self.oft_blocks[0].device: + self.I = [I.to(self.oft_blocks[0].device) for I in self.I] + + block_R_weighted_list = [] + for i in range(len(self.oft_blocks)): + block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i]) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + + I = self.I[i] + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + + block_R_weighted_list.append(block_R_weighted) + + return block_R_weighted_list + + def forward(self, x, scale=None): + if self.multiplier == 0.0: + return self.org_forward(x) + + org_module = self.org_module[0] + org_dtype = x.dtype + + R = self.get_weight() + W = org_module.weight.to(torch.float32) + B = org_module.bias.to(torch.float32) + + # split W to match R + results = [] + d2 = 0 + for i in range(len(R)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p") + + B1 = B[d1:d2] + result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype)) + results.append(result) + + result = torch.cat(results, dim=-1) + return result + + +class OFTInfModule(OFTModule): + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + split_dims: Optional[List[int]] = None, + **kwargs, + ): + # no dropout for inference + super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims) + self.enabled = True + self.network: OFTNetwork = None + + def set_network(self, network): + self.network = network + + def forward(self, x, scale=None): + if not self.enabled: + return self.org_forward(x) + return super().forward(x, scale) + + def merge_to(self, multiplier=None): + # get org weight + org_sd = self.org_module[0].state_dict() + W = org_sd["weight"].to(torch.float32) + R = self.get_weight(multiplier).to(torch.float32) + + d2 = 0 + W_list = [] + for i in range(len(self.oft_blocks)): + d1 = d2 + d2 += self.split_dims[i] + + W1 = W[d1:d2] + W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i]) + W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped) + W1 = einops.rearrange(W1, "k m p -> (k m) p") + + W_list.append(W1) + + W = torch.cat(W_list, dim=-1) + + # convert back to original dtype + W = W.to(org_sd["weight"].dtype) + + # set weight to org_module + org_sd["weight"] = W + self.org_module[0].load_state_dict(org_sd) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) + + # attn only or all linear (FFN) layers + enable_all_linear = kwargs.get("enable_all_linear", None) + # enable_conv = kwargs.get("enable_conv", None) + if enable_all_linear is not None: + enable_all_linear = bool(enable_all_linear) + # if enable_conv is not None: + # enable_conv = bool(enable_conv) + + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=network_dim, + alpha=network_alpha, + enable_all_linear=enable_all_linear, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # check dim, alpha and if weights have for conv2d + dim = None + alpha = None + all_linear = None + for name, param in weights_sd.items(): + if name.endswith(".alpha"): + if alpha is None: + alpha = param.item() + elif "qkv" in name: + continue # ignore qkv + else: + if dim is None: + dim = param.size()[0] + if all_linear is None and "_mlp" in name: + all_linear = True + if dim is not None and alpha is not None and all_linear is not None: + break + if all_linear is None: + all_linear = False + + module_class = OFTInfModule if for_inference else OFTModule + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=dim, + alpha=alpha, + enable_all_linear=all_linear, + module_class=module_class, + ) + return network, weights_sd + + +class OFTNetwork(torch.nn.Module): + FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"] + OFT_PREFIX_UNET = "oft_unet" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + dim: int = 4, + alpha: float = 1, + enable_all_linear: Optional[bool] = False, + module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.train_t5xxl = False # make compatible with LoRA + self.multiplier = multiplier + + self.dim = dim + self.alpha = alpha + + logger.info( + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}" + ) + + # create module instances + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[OFTModule]: + prefix = self.OFT_PREFIX_UNET + ofts = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = "Linear" in child_module.__class__.__name__ + + if is_linear: + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + # logger.info(oft_name) + + if "double" in oft_name and "qkv" in oft_name: + split_dims = [3072] * 3 + elif "single" in oft_name and "linear1" in oft_name: + split_dims = [3072] * 3 + [12288] + else: + split_dims = None + + oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims) + ofts.append(oft) + return ofts + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + if enable_all_linear: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR + else: + target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY + + self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) + logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.") + + # assertion + names = set() + for oft in self.unet_ofts: + assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}" + names.add(oft.oft_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for oft in self.unet_ofts: + oft.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + assert apply_unet, "apply_unet must be True" + + for oft in self.unet_ofts: + oft.apply_to() + self.add_module(oft.oft_name, oft) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + logger.info("enable OFT for U-Net") + + for oft in self.unet_ofts: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(oft.oft_name): + sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key] + oft.load_state_dict(sd_for_lora, False) + oft.merge_to() + + logger.info(f"weights are merged") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(ofts): + params = [] + for oft in ofts: + params.extend(oft.parameters()) + + # logger.info num of params + num_params = 0 + for p in params: + num_params += p.numel() + logger.info(f"OFT params: {num_params}") + return params + + param_data = {"params": enumerate_params(self.unet_ofts)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + oft.merge_to() + # sd = org_module.state_dict() + # org_weight = sd["weight"] + # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype) + # sd["weight"] = org_weight + lora_weight + # assert sd["weight"].shape == org_weight.shape + # org_module.load_state_dict(sd) + + org_module._lora_restored = False + oft.enabled = False