import os
import re
from typing import Dict, List, Optional, Union
import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def filter_lora_state_dict(
    weights_sd: Dict[str, torch.Tensor],
    include_pattern: Optional[str] = None,
    exclude_pattern: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    """
    Filter a LoRA state dict with optional regex include/exclude patterns applied to its keys.
    """
    # apply include/exclude patterns
    original_key_count = len(weights_sd.keys())
    if include_pattern is not None:
        regex_include = re.compile(include_pattern)
        weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
        logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")

    if exclude_pattern is not None:
        original_key_count_ex = len(weights_sd.keys())
        regex_exclude = re.compile(exclude_pattern)
        weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
        logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")

    if len(weights_sd) != original_key_count:
        remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
        remaining_keys.sort()
        logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
        if len(weights_sd) == 0:
            logger.warning("No keys left after filtering.")

    return weights_sd
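
# Example usage of filter_lora_state_dict (illustrative sketch only; the file name and the
# regex patterns below are hypothetical and do not refer to anything in this repository):
#
#   from safetensors.torch import load_file
#
#   lora_sd = load_file("my_lora.safetensors")
#   # keep only attention modules, then drop anything that targets the text encoder
#   lora_sd = filter_lora_state_dict(lora_sd, include_pattern=r"attn", exclude_pattern=r"text_encoder")

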
def load_safetensors_with_lora_and_fp8(
    model_files: Union[str, List[str]],
    lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
    lora_multipliers: Optional[List[float]],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Merge LoRA weights into the state dict of a model, with fp8 optimization if needed.

    Args:
        model_files (Union[str, List[str]]): Path to the model file or list of paths. If a path matches a pattern like `00001-of-00004`, all files with the same prefix are loaded.
        lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
        lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
        fp8_optimization (bool): Whether to apply FP8 optimization.
        calc_device (torch.device): Device to calculate on.
        move_to_device (bool): Whether to move tensors to the calculation device after loading.
        dit_weight_dtype (Optional[torch.dtype]): Dtype to cast weights to when FP8 optimization is not applied; ignored when fp8_optimization is True.
        target_keys (Optional[List[str]]): Keys to target for optimization.
        exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
        disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
        weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
    """

    # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
    if isinstance(model_files, str):
        model_files = [model_files]

    extended_model_files = []
    for model_file in model_files:
        split_filenames = get_split_weight_filenames(model_file)
        if split_filenames is not None:
            extended_model_files.extend(split_filenames)
        else:
            extended_model_files.append(model_file)
    model_files = extended_model_files
    logger.info(f"Loading model files: {model_files}")

    # load LoRA weights
    weight_hook = None
    if lora_weights_list is None or len(lora_weights_list) == 0:
        lora_weights_list = []
        lora_multipliers = []
        list_of_lora_weight_keys = []
    else:
        list_of_lora_weight_keys = []
        for lora_sd in lora_weights_list:
            lora_weight_keys = set(lora_sd.keys())
            list_of_lora_weight_keys.append(lora_weight_keys)

        if lora_multipliers is None:
            lora_multipliers = [1.0] * len(lora_weights_list)
        while len(lora_multipliers) < len(lora_weights_list):
            lora_multipliers.append(1.0)
        if len(lora_multipliers) > len(lora_weights_list):
            lora_multipliers = lora_multipliers[: len(lora_weights_list)]

        # Merge LoRA weights into the state dict
        logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")

        # make hook for LoRA merging
        def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
            nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device

            if not model_weight_key.endswith(".weight"):
                return model_weight

            original_device = model_weight.device
            if original_device != calc_device:
                model_weight = model_weight.to(calc_device)  # to make calculation faster

            for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
                # check if this weight has LoRA weights
                lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0]  # remove trailing ".weight"
                found = False
                for prefix in ["lora_unet_", ""]:
                    lora_name = prefix + lora_name_without_prefix.replace(".", "_")
                    down_key = lora_name + ".lora_down.weight"
                    up_key = lora_name + ".lora_up.weight"
                    alpha_key = lora_name + ".alpha"
                    if down_key in lora_weight_keys and up_key in lora_weight_keys:
                        found = True
                        break
                if not found:
                    continue  # no LoRA weights for this model weight

                # get LoRA weights
                down_weight = lora_sd[down_key]
                up_weight = lora_sd[up_key]

                dim = down_weight.size()[0]
                alpha = lora_sd.get(alpha_key, dim)
                scale = alpha / dim

                down_weight = down_weight.to(calc_device)
                up_weight = up_weight.to(calc_device)

                original_dtype = model_weight.dtype
                if original_dtype.itemsize == 1:  # fp8
                    # temporarily convert to float16 for calculation
                    model_weight = model_weight.to(torch.float16)
                    down_weight = down_weight.to(torch.float16)
                    up_weight = up_weight.to(torch.float16)

                # W <- W + U * D
                if len(model_weight.size()) == 2:
                    # linear
                    if len(up_weight.size()) == 4:  # LoRA weights stored as 4D (1x1 conv style) for a linear layer; squeeze to 2D
                        up_weight = up_weight.squeeze(3).squeeze(2)
                        down_weight = down_weight.squeeze(3).squeeze(2)
                    model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
                elif down_weight.size()[2:4] == (1, 1):
                    # conv2d 1x1
                    model_weight = (
                        model_weight
                        + multiplier
                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                        * scale
                    )
                else:
                    # conv2d 3x3
                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                    # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                    model_weight = model_weight + multiplier * conved * scale

                if original_dtype.itemsize == 1:  # fp8
                    model_weight = model_weight.to(original_dtype)  # convert back to original dtype

                # remove LoRA keys from set
                lora_weight_keys.remove(down_key)
                lora_weight_keys.remove(up_key)
                if alpha_key in lora_weight_keys:
                    lora_weight_keys.remove(alpha_key)

            if not keep_on_calc_device and original_device != calc_device:
                model_weight = model_weight.to(original_device)  # move back to original device
            return model_weight

        weight_hook = weight_hook_func

    state_dict = load_safetensors_with_fp8_optimization_and_hook(
        model_files,
        fp8_optimization,
        calc_device,
        move_to_device,
        dit_weight_dtype,
        target_keys,
        exclude_keys,
        weight_hook=weight_hook,
        disable_numpy_memmap=disable_numpy_memmap,
        weight_transform_hooks=weight_transform_hooks,
    )

    for lora_weight_keys in list_of_lora_weight_keys:
        # check if all LoRA keys are used
        if len(lora_weight_keys) > 0:
            # if there are still LoRA keys left, it means they are not used in the model
            # this is a warning, not an error
            logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")

    return state_dict
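
# Example usage of load_safetensors_with_lora_and_fp8 (illustrative sketch only; the file
# names, multiplier, and device below are hypothetical, not taken from this repository):
#
#   from safetensors.torch import load_file
#
#   lora_sd = load_file("anima_lora.safetensors")
#   state_dict = load_safetensors_with_lora_and_fp8(
#       model_files="anima_dit.safetensors",
#       lora_weights_list=[lora_sd],
#       lora_multipliers=[0.8],
#       fp8_optimization=True,
#       calc_device=torch.device("cuda"),
#       move_to_device=True,
#   )
#   # the merged (and optionally fp8-optimized) state dict can then be passed to the model's
#   # load_state_dict

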
def load_safetensors_with_fp8_optimization_and_hook(
    model_files: list[str],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    weight_hook: callable = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Load a state dict from safetensors files with fp8 optimization if needed, applying the optional weight hook (e.g. for LoRA merging) while loading.
    """
    if fp8_optimization:
        logger.info(
            f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        # dit_weight_dtype is not used because we use fp8 optimization
        state_dict = load_safetensors_with_fp8_optimization(
            model_files,
            calc_device,
            target_keys,
            exclude_keys,
            move_to_device=move_to_device,
            weight_hook=weight_hook,
            disable_numpy_memmap=disable_numpy_memmap,
            weight_transform_hooks=weight_transform_hooks,
        )
    else:
        logger.info(
            f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        state_dict = {}
        for model_file in model_files:
            with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
                f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
                for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
                    if weight_hook is None and move_to_device:
                        value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
                    else:
                        value = f.get_tensor(key)  # we cannot directly load to device because get_tensor does non-blocking transfer
                        if weight_hook is not None:
                            value = weight_hook(key, value, keep_on_calc_device=move_to_device)
                        if move_to_device:
                            value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
                        elif dit_weight_dtype is not None:
                            value = value.to(dit_weight_dtype)

                    state_dict[key] = value
                if move_to_device:
                    synchronize_device(calc_device)

    return state_dict
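

# Minimal self-check of the merge arithmetic used by weight_hook_func above (illustrative
# only, not part of the library API): merging W <- W + multiplier * (up @ down) * (alpha / dim)
# into a linear weight must give the same outputs as applying the LoRA branch at runtime.
# The sizes and values below are arbitrary.
if __name__ == "__main__":
    torch.manual_seed(0)
    out_features, in_features, dim, alpha, multiplier = 8, 6, 4, 2.0, 1.0
    w = torch.randn(out_features, in_features)
    down_weight = torch.randn(dim, in_features)
    up_weight = torch.randn(out_features, dim)
    x = torch.randn(3, in_features)

    scale = alpha / dim
    merged = w + multiplier * (up_weight @ down_weight) * scale
    runtime = x @ w.T + multiplier * ((x @ down_weight.T) @ up_weight.T) * scale
    assert torch.allclose(x @ merged.T, runtime, atol=1e-5)
    print("LoRA merge matches runtime LoRA application.")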