diff --git a/library/flux_utils.py b/library/flux_utils.py index 86a2ec60..f3093615 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -54,6 +54,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int with safe_open(ckpt_path, framework="pt") as f: keys.extend(f.keys()) + # if the key has annoying prefix, remove it + if keys[0].startswith("model.diffusion_model."): + keys = [key.replace("model.diffusion_model.", "") for key in keys] + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) @@ -122,6 +126,13 @@ def load_flow_model( sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) logger.info("Converted Diffusers to BFL") + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") return is_schnell, model diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py index cbabf8da..efe20245 100644 --- a/networks/lora_sd3.py +++ b/networks/lora_sd3.py @@ -307,6 +307,7 @@ class LoRANetwork(torch.nn.Module): target_replace_modules: List[str], filter: Optional[str] = None, default_dim: Optional[int] = None, + include_conv2d_if_filter: bool = False, ) -> List[LoRAModule]: prefix = ( self.LORA_PREFIX_SD3 @@ -332,8 +333,11 @@ class LoRANetwork(torch.nn.Module): lora_name = prefix + "." + (name + "." if name else "") + child_name lora_name = lora_name.replace(".", "_") - if filter is not None and not filter in lora_name: - continue + force_incl_conv2d = False + if filter is not None: + if not filter in lora_name: + continue + force_incl_conv2d = include_conv2d_if_filter dim = None alpha = None @@ -373,6 +377,10 @@ class LoRANetwork(torch.nn.Module): elif self.conv_lora_dim is not None: dim = self.conv_lora_dim alpha = self.conv_alpha + elif force_incl_conv2d: + # x_embedder + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha if dim is None or dim == 0: # skipした情報を出力 @@ -428,7 +436,7 @@ class LoRANetwork(torch.nn.Module): for filter, in_dim in zip( [ "context_embedder", - "t_embedder", + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" "x_embedder", "y_embedder", "final_layer_adaLN_modulation", @@ -436,7 +444,12 @@ class LoRANetwork(torch.nn.Module): ], self.emb_dims, ): - loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + # x_embedder is conv2d, so we need to include it + loras, _ = create_modules( + True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder" + ) + # if len(loras) > 0: + # logger.info(f"create LoRA for {filter}: {len(loras)} modules.") self.unet_loras.extend(loras) logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") @@ -540,8 +553,8 @@ class LoRANetwork(torch.nn.Module): down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) # merge up weight (sum of split_dim, rank*3) - qkv_dim, rank = up_weights[0].size() - split_dim = qkv_dim // 3 + split_dim, rank = up_weights[0].size() + qkv_dim = split_dim * 3 up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) i = 0 for j in range(3): diff --git a/sd3_minimal_inference.py b/sd3_minimal_inference.py index d099fe18..86dba246 100644 --- a/sd3_minimal_inference.py +++ b/sd3_minimal_inference.py @@ -10,11 +10,13 @@ import numpy as np import torch from safetensors.torch import safe_open, load_file +import torch.amp from tqdm import tqdm from PIL import Image from transformers import CLIPTextModelWithProjection, T5EncoderModel from library.device_utils import init_ipex, get_preferred_device +from networks import lora_sd3 init_ipex() @@ -104,7 +106,8 @@ def do_sample( x_c_nc = torch.cat([x, x], dim=0) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) - model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + with torch.autocast(device_type=device.type, dtype=dtype): + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = model_output.float() batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) @@ -153,7 +156,7 @@ def generate_image( clip_g.to(device) t5xxl.to(device) - with torch.no_grad(): + with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad(): tokens_and_masks = tokenize_strategy.tokenize(prompt) lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens( tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask @@ -233,13 +236,14 @@ if __name__ == "__main__": parser.add_argument("--bf16", action="store_true") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--steps", type=int, default=50) - # parser.add_argument( - # "--lora_weights", - # type=str, - # nargs="*", - # default=[], - # help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", - # ) + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--interactive", action="store_true") @@ -294,6 +298,30 @@ if __name__ == "__main__": tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() + # LoRA + lora_models: list[lora_sd3.LoRANetwork] = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + module = lora_sd3 + lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True) + + if args.merge_lora_weights: + lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd) + else: + lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) + if not args.interactive: generate_image( mmdit, @@ -344,13 +372,13 @@ if __name__ == "__main__": steps = int(opt[1:].strip()) elif opt.startswith("d"): seed = int(opt[1:].strip()) - # elif opt.startswith("m"): - # mutipliers = opt[1:].strip().split(",") - # if len(mutipliers) != len(lora_models): - # logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") - # continue - # for i, lora_model in enumerate(lora_models): - # lora_model.set_multiplier(float(mutipliers[i])) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) elif opt.startswith("n"): negative_prompt = opt[1:].strip() if negative_prompt == "-":