mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Compare commits
5 Commits
a029c38b4b
...
b3de44417e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3de44417e | ||
|
|
5948a59e89 | ||
|
|
cc1f57bc70 | ||
|
|
57aa70ea9e | ||
|
|
3612bedda6 |
@@ -13,7 +13,7 @@ import toml
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library import flux_train_utils, qwen_image_autoencoder_kl, utils
|
||||
from library import flux_train_utils, qwen_image_autoencoder_kl
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
@@ -72,16 +72,6 @@ def train(args):
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||
|
||||
# # Flash attention: validate availability
|
||||
# if args.flash_attn:
|
||||
# try:
|
||||
# import flash_attn # noqa: F401
|
||||
|
||||
# logger.info("Flash Attention enabled for DiT blocks")
|
||||
# except ImportError:
|
||||
# logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
|
||||
# args.flash_attn = False
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
@@ -348,7 +338,7 @@ def train(args):
|
||||
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
else:
|
||||
dit_weight_dtype = torch.float32 # Default to float32
|
||||
dit_weight_dtype = torch.float32 # If neither full_fp16 nor full_bf16, the model weights should be in float32
|
||||
dit.to(dit_weight_dtype) # convert dit to target weight dtype
|
||||
|
||||
# move text encoder to GPU if not cached
|
||||
@@ -431,6 +421,7 @@ def train(args):
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
# Copy for noise and timestep generation, because noise_scheduler may be changed during training in future
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
@@ -540,7 +531,7 @@ def train(args):
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, dit_weight_dtype
|
||||
args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
|
||||
)
|
||||
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
|
||||
|
||||
|
||||
@@ -352,10 +352,6 @@ class Attention(nn.Module):
|
||||
|
||||
return q, k, v
|
||||
|
||||
# def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
# result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
# return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
@@ -5,7 +5,7 @@ import gc
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file, save_file
|
||||
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
|
||||
from accelerate import init_empty_weights
|
||||
@@ -38,11 +37,11 @@ def load_anima_model(
|
||||
loading_device: Union[str, torch.device],
|
||||
dit_weight_dtype: Optional[torch.dtype],
|
||||
fp8_scaled: bool = False,
|
||||
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
|
||||
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
|
||||
lora_multipliers: Optional[list[float]] = None,
|
||||
) -> anima_models.Anima:
|
||||
"""
|
||||
Load a HunyuanImage model from the specified checkpoint.
|
||||
Load Anima model from the specified checkpoint.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): Device for optimization or merging
|
||||
@@ -53,7 +52,7 @@ def load_anima_model(
|
||||
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
|
||||
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
|
||||
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
|
||||
lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any.
|
||||
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
|
||||
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
|
||||
"""
|
||||
# dit_weight_dtype is None for fp8_scaled
|
||||
|
||||
@@ -28,8 +28,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
import logging
|
||||
|
||||
from library.safetensors_utils import load_safetensors
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
156
networks/convert_anima_lora_to_comfy.py
Normal file
156
networks/convert_anima_lora_to_comfy.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import argparse
|
||||
from safetensors.torch import save_file
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
from library import train_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main(args):
|
||||
# load source safetensors
|
||||
logger.info(f"Loading source file {args.src_path}")
|
||||
state_dict = {}
|
||||
with safe_open(args.src_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
|
||||
logger.info(f"Converting...")
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
count = 0
|
||||
|
||||
for k in keys:
|
||||
if not args.reverse:
|
||||
is_dit_lora = k.startswith("lora_unet_")
|
||||
module_and_weight_name = "_".join(k.split("_")[2:]) # Remove `lora_unet_`or `lora_te_` prefix
|
||||
|
||||
# Split at the first dot, e.g., "block1_linear.weight" -> "block1_linear", "weight"
|
||||
module_name, weight_name = module_and_weight_name.split(".", 1)
|
||||
|
||||
# Weight name conversion: lora_up/lora_down to lora_A/lora_B
|
||||
if weight_name.startswith("lora_up"):
|
||||
weight_name = weight_name.replace("lora_up", "lora_B")
|
||||
elif weight_name.startswith("lora_down"):
|
||||
weight_name = weight_name.replace("lora_down", "lora_A")
|
||||
else:
|
||||
# Keep other weight names as-is: e.g. alpha
|
||||
pass
|
||||
|
||||
# Module name conversion: convert dots to underscores
|
||||
original_module_name = module_name.replace("_", ".") # Convert to dot notation
|
||||
|
||||
# Convert back illegal dots in module names
|
||||
# DiT
|
||||
original_module_name = original_module_name.replace(".linear.", ".linear_")
|
||||
original_module_name = original_module_name.replace("t.embedding.norm", "t_embedding_norm")
|
||||
original_module_name = original_module_name.replace("x.embedder", "x_embedder")
|
||||
original_module_name = original_module_name.replace("adaln.modulation.cross_attn", "adaln_modulation_cross_attn")
|
||||
original_module_name = original_module_name.replace("adaln.modulation.mlp", "adaln_modulation_mlp")
|
||||
original_module_name = original_module_name.replace("cross.attn", "cross_attn")
|
||||
original_module_name = original_module_name.replace("k.proj", "k_proj")
|
||||
original_module_name = original_module_name.replace("k.norm", "k_norm")
|
||||
original_module_name = original_module_name.replace("q.proj", "q_proj")
|
||||
original_module_name = original_module_name.replace("q.norm", "q_norm")
|
||||
original_module_name = original_module_name.replace("v.proj", "v_proj")
|
||||
original_module_name = original_module_name.replace("o.proj", "o_proj")
|
||||
original_module_name = original_module_name.replace("output.proj", "output_proj")
|
||||
original_module_name = original_module_name.replace("self.attn", "self_attn")
|
||||
original_module_name = original_module_name.replace("final.layer", "final_layer")
|
||||
original_module_name = original_module_name.replace("adaln.modulation", "adaln_modulation")
|
||||
original_module_name = original_module_name.replace("norm.cross.attn", "norm_cross_attn")
|
||||
original_module_name = original_module_name.replace("norm.mlp", "norm_mlp")
|
||||
original_module_name = original_module_name.replace("norm.self.attn", "norm_self_attn")
|
||||
original_module_name = original_module_name.replace("out.proj", "out_proj")
|
||||
|
||||
# Qwen3
|
||||
original_module_name = original_module_name.replace("embed.tokens", "embed_tokens")
|
||||
original_module_name = original_module_name.replace("input.layernorm", "input_layernorm")
|
||||
original_module_name = original_module_name.replace("down.proj", "down_proj")
|
||||
original_module_name = original_module_name.replace("gate.proj", "gate_proj")
|
||||
original_module_name = original_module_name.replace("up.proj", "up_proj")
|
||||
original_module_name = original_module_name.replace("post.attention.layernorm", "post_attention_layernorm")
|
||||
|
||||
# Prefix conversion
|
||||
new_prefix = "diffusion_model." if is_dit_lora else "text_encoder.qwen3."
|
||||
|
||||
new_k = f"{new_prefix}{original_module_name}.{weight_name}"
|
||||
else:
|
||||
if k.startswith("diffusion_model."):
|
||||
is_dit_lora = True
|
||||
module_and_weight_name = k[len("diffusion_model.") :]
|
||||
elif k.startswith("text_encoder.qwen3."):
|
||||
is_dit_lora = False
|
||||
module_and_weight_name = k[len("text_encoder.qwen3.") :]
|
||||
else:
|
||||
logger.warning(f"Skipping unrecognized key {k}")
|
||||
continue
|
||||
|
||||
# Get weight name
|
||||
if ".lora_" in module_and_weight_name:
|
||||
module_name, weight_name = module_and_weight_name.rsplit(".lora_", 1)
|
||||
weight_name = "lora_" + weight_name
|
||||
else:
|
||||
module_name, weight_name = module_and_weight_name.rsplit(".", 1) # Keep other weight names as-is: e.g. alpha
|
||||
|
||||
# Weight name conversion: lora_A/lora_B to lora_up/lora_down
|
||||
# Note: we only convert lora_A and lora_B weights, other weights are kept as-is
|
||||
if weight_name.startswith("lora_B"):
|
||||
weight_name = weight_name.replace("lora_B", "lora_up")
|
||||
elif weight_name.startswith("lora_A"):
|
||||
weight_name = weight_name.replace("lora_A", "lora_down")
|
||||
|
||||
# Module name conversion: convert dots to underscores
|
||||
module_name = module_name.replace(".", "_") # Convert to underscore notation
|
||||
|
||||
# Prefix conversion
|
||||
prefix = "lora_unet_" if is_dit_lora else "lora_te_"
|
||||
|
||||
new_k = f"{prefix}{module_name}.{weight_name}"
|
||||
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
count += 1
|
||||
|
||||
logger.info(f"Converted {count} keys")
|
||||
if count == 0:
|
||||
logger.warning("No keys were converted. Please check if the source file is in the expected format.")
|
||||
elif count > 0 and count < len(keys):
|
||||
logger.warning(
|
||||
f"Only {count} out of {len(keys)} keys were converted. Please check if there are unexpected keys in the source file."
|
||||
)
|
||||
|
||||
# Calculate hash
|
||||
if metadata is not None:
|
||||
logger.info(f"Calculating hashes and creating metadata...")
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
# save destination safetensors
|
||||
logger.info(f"Saving destination file {args.dst_path}")
|
||||
save_file(state_dict, args.dst_path, metadata=metadata)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert LoRA format")
|
||||
parser.add_argument(
|
||||
"src_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="source path, sd-scripts format (or ComfyUI compatible format if --reverse is set, only supported for LoRAs converted by this script)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"dst_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination path, ComfyUI compatible format (or sd-scripts format if --reverse is set)",
|
||||
)
|
||||
parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
Reference in New Issue
Block a user