mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
feat: support LoRA format without net. prefix
This commit is contained in:
@@ -212,7 +212,7 @@ def check_inputs(args: argparse.Namespace) -> Tuple[int, int]:
|
||||
|
||||
def load_dit_model(
|
||||
args: argparse.Namespace, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None
|
||||
) -> anima_models.MiniTrainDIT:
|
||||
) -> anima_models.Anima:
|
||||
"""load DiT model
|
||||
|
||||
Args:
|
||||
@@ -221,7 +221,7 @@ def load_dit_model(
|
||||
dit_weight_dtype: data type for the model weights. None for as-is
|
||||
|
||||
Returns:
|
||||
anima_models.MiniTrainDIT: DiT model instance
|
||||
anima_models.Anima: DiT model instance
|
||||
"""
|
||||
# If LyCORIS is enabled, we will load the model to CPU and then merge LoRA weights (static method)
|
||||
|
||||
@@ -439,7 +439,7 @@ def generate(
|
||||
else:
|
||||
# use shared model
|
||||
logger.info("Using shared DiT model.")
|
||||
anima: anima_models.MiniTrainDIT = shared_models["model"]
|
||||
anima: anima_models.Anima = shared_models["model"]
|
||||
|
||||
if precomputed_text_data is not None:
|
||||
logger.info("Using precomputed text data.")
|
||||
@@ -455,7 +455,7 @@ def generate(
|
||||
|
||||
def generate_body(
|
||||
args: Union[argparse.Namespace, SimpleNamespace],
|
||||
anima: anima_models.MiniTrainDIT,
|
||||
anima: anima_models.Anima,
|
||||
context: Dict[str, Any],
|
||||
context_null: Optional[Dict[str, Any]],
|
||||
device: torch.device,
|
||||
@@ -479,7 +479,7 @@ def generate_body(
|
||||
negative_embed = context_null["embed"][0].to(device, dtype=torch.bfloat16)
|
||||
|
||||
# Prepare latent variables
|
||||
num_channels_latents = 16 # anima_models.MiniTrainDIT.LATENT_CHANNELS
|
||||
num_channels_latents = anima_models.Anima.LATENT_CHANNELS
|
||||
shape = (
|
||||
1,
|
||||
num_channels_latents,
|
||||
|
||||
@@ -1049,13 +1049,15 @@ class Block(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# Main DiT Model: MiniTrainDIT
|
||||
class MiniTrainDIT(nn.Module):
|
||||
# Main DiT Model: MiniTrainDIT (renamed to Anima)
|
||||
class Anima(nn.Module):
|
||||
"""Cosmos-Predict2 DiT model for image/video generation.
|
||||
|
||||
28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter.
|
||||
"""
|
||||
|
||||
LATENT_CHANNELS = 16
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_img_h: int,
|
||||
@@ -1307,7 +1309,7 @@ class MiniTrainDIT(nn.Module):
|
||||
return
|
||||
self.offloader.prepare_block_devices_before_forward(self.blocks)
|
||||
|
||||
def forward(
|
||||
def forward_mini_train_dit(
|
||||
self,
|
||||
x_B_C_T_H_W: torch.Tensor,
|
||||
timesteps_B_T: torch.Tensor,
|
||||
@@ -1378,6 +1380,41 @@ class MiniTrainDIT(nn.Module):
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    target_input_ids: Optional[torch.Tensor] = None,
    target_attention_mask: Optional[torch.Tensor] = None,
    source_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Preprocess the text embeddings, then run the DiT forward pass.

    ``context`` is first passed through ``_preprocess_text_embeds`` together with
    the target/source token information; the result replaces ``context`` in the
    call to ``forward_mini_train_dit``. ``fps``, ``padding_mask`` and any extra
    keyword arguments are forwarded unchanged.

    Returns:
        Whatever ``forward_mini_train_dit`` returns for the given inputs.
    """
    text_embeds = self._preprocess_text_embeds(
        context,
        target_input_ids,
        target_attention_mask,
        source_attention_mask,
    )
    return self.forward_mini_train_dit(
        x,
        timesteps,
        text_embeds,
        fps=fps,
        padding_mask=padding_mask,
        **kwargs,
    )
|
||||
|
||||
def _preprocess_text_embeds(
|
||||
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
|
||||
):
|
||||
if target_input_ids is not None:
|
||||
# print(
|
||||
# f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}"
|
||||
# )
|
||||
# print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
|
||||
context = self.llm_adapter(
|
||||
source_hidden_states,
|
||||
target_input_ids,
|
||||
target_attention_mask=target_attention_mask,
|
||||
source_attention_mask=source_attention_mask,
|
||||
)
|
||||
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
|
||||
# print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
|
||||
return context
|
||||
else:
|
||||
return source_hidden_states
|
||||
|
||||
|
||||
# LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space
|
||||
class LLMAdapterRMSNorm(nn.Module):
|
||||
@@ -1531,33 +1568,15 @@ class LLMAdapterTransformerBlock(nn.Module):
|
||||
)
|
||||
x = x + attn_out
|
||||
|
||||
if source_attention_mask is not None:
|
||||
# Select batch elements where source_attention_mask has at least one True value
|
||||
batch_indices = torch.where(source_attention_mask.any(dim=(1, 2, 3)))[0]
|
||||
# print("Batch indices for cross-attention:", batch_indices)
|
||||
if len(batch_indices) == 0:
|
||||
pass # No valid batch elements, skip cross-attention
|
||||
else:
|
||||
normed = self.norm_cross_attn(x[batch_indices])
|
||||
attn_out = self.cross_attn(
|
||||
normed,
|
||||
mask=None,
|
||||
context=context[batch_indices],
|
||||
position_embeddings=position_embeddings,
|
||||
position_embeddings_context=position_embeddings_context,
|
||||
)
|
||||
x[batch_indices] = x[batch_indices] + attn_out
|
||||
else:
|
||||
# Standard cross-attention without masking
|
||||
normed = self.norm_cross_attn(x)
|
||||
attn_out = self.cross_attn(
|
||||
normed,
|
||||
mask=source_attention_mask,
|
||||
context=context,
|
||||
position_embeddings=position_embeddings,
|
||||
position_embeddings_context=position_embeddings_context,
|
||||
)
|
||||
x = x + attn_out
|
||||
normed = self.norm_cross_attn(x)
|
||||
attn_out = self.cross_attn(
|
||||
normed,
|
||||
mask=source_attention_mask,
|
||||
context=context,
|
||||
position_embeddings=position_embeddings,
|
||||
position_embeddings_context=position_embeddings_context,
|
||||
)
|
||||
x = x + attn_out
|
||||
|
||||
x = x + self.mlp(self.norm_mlp(x))
|
||||
return x
|
||||
@@ -1623,76 +1642,6 @@ class LLMAdapter(nn.Module):
|
||||
return self.norm(self.out_proj(x))
|
||||
|
||||
|
||||
class Anima(nn.Module):
|
||||
"""
|
||||
Wrapper class for the MiniTrainDIT and LLM Adapter.
|
||||
"""
|
||||
|
||||
LATENT_CHANNELS = 16
|
||||
|
||||
def __init__(self, dit_config: dict):
|
||||
super().__init__()
|
||||
self.net = MiniTrainDIT(**dit_config)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.net.device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.net.dtype
|
||||
|
||||
def enable_gradient_checkpointing(self, *args, **kwargs):
|
||||
self.net.enable_gradient_checkpointing(*args, **kwargs)
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.net.disable_gradient_checkpointing()
|
||||
|
||||
def enable_block_swap(self, *args, **kwargs):
|
||||
self.net.enable_block_swap(*args, **kwargs)
|
||||
|
||||
def move_to_device_except_swap_blocks(self, *args, **kwargs):
|
||||
self.net.move_to_device_except_swap_blocks(*args, **kwargs)
|
||||
|
||||
def prepare_block_swap_before_forward(self, *args, **kwargs):
|
||||
self.net.prepare_block_swap_before_forward(*args, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
target_input_ids: Optional[torch.Tensor] = None,
|
||||
target_attention_mask: Optional[torch.Tensor] = None,
|
||||
source_attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
|
||||
return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
|
||||
|
||||
def _preprocess_text_embeds(
|
||||
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
|
||||
):
|
||||
if target_input_ids is not None:
|
||||
# print(
|
||||
# f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}"
|
||||
# )
|
||||
# print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
|
||||
context = self.net.llm_adapter(
|
||||
source_hidden_states,
|
||||
target_input_ids,
|
||||
target_attention_mask=target_attention_mask,
|
||||
source_attention_mask=source_attention_mask,
|
||||
)
|
||||
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
|
||||
# print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
|
||||
return context
|
||||
else:
|
||||
return source_hidden_states
|
||||
|
||||
|
||||
# VAE Wrapper
|
||||
|
||||
# VAE normalization constants
|
||||
|
||||
@@ -225,7 +225,7 @@ def get_anima_param_groups(
|
||||
"""Create parameter groups for Anima training with separate learning rates.
|
||||
|
||||
Args:
|
||||
dit: MiniTrainDIT model
|
||||
dit: Anima model
|
||||
base_lr: Base learning rate
|
||||
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
|
||||
cross_attn_lr: LR for cross-attention layers
|
||||
@@ -307,7 +307,7 @@ def save_anima_model_on_train_end(
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
dit: anima_models.Anima,
|
||||
):
|
||||
"""Save Anima model at the end of training."""
|
||||
|
||||
@@ -328,7 +328,7 @@ def save_anima_model_on_epoch_end_or_stepwise(
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
dit: anima_models.Anima,
|
||||
):
|
||||
"""Save Anima model at epoch end or specific steps."""
|
||||
|
||||
@@ -356,7 +356,7 @@ def do_sample(
|
||||
height: int,
|
||||
width: int,
|
||||
seed: Optional[int],
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
dit: anima_models.Anima,
|
||||
crossattn_emb: torch.Tensor,
|
||||
steps: int,
|
||||
dtype: torch.dtype,
|
||||
@@ -370,7 +370,7 @@ def do_sample(
|
||||
Args:
|
||||
height, width: Output image dimensions
|
||||
seed: Random seed (None for random)
|
||||
dit: MiniTrainDIT model
|
||||
dit: Anima model
|
||||
crossattn_emb: Cross-attention embeddings (B, N, D)
|
||||
steps: Number of sampling steps
|
||||
dtype: Compute dtype
|
||||
@@ -588,8 +588,8 @@ def _sample_image_inference(
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# Process through LLM adapter if available
|
||||
if dit.net.use_llm_adapter:
|
||||
crossattn_emb = dit.net.llm_adapter(
|
||||
if dit.use_llm_adapter:
|
||||
crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=prompt_embeds,
|
||||
target_input_ids=t5_input_ids,
|
||||
target_attention_mask=t5_attn_mask,
|
||||
@@ -616,8 +616,8 @@ def _sample_image_inference(
|
||||
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
|
||||
neg_t5_am = neg_t5_am.to(accelerator.device)
|
||||
|
||||
if dit.net.use_llm_adapter:
|
||||
neg_crossattn_emb = dit.net.llm_adapter(
|
||||
if dit.use_llm_adapter:
|
||||
neg_crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=neg_pe,
|
||||
target_input_ids=neg_t5_ids,
|
||||
target_attention_mask=neg_t5_am,
|
||||
|
||||
@@ -11,6 +11,7 @@ from accelerate import init_empty_weights
|
||||
from library.fp8_optimization_utils import apply_fp8_monkey_patch
|
||||
from library.lora_utils import load_safetensors_with_lora_and_fp8
|
||||
from library import anima_models
|
||||
from library.safetensors_utils import WeightTransformHooks
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -19,7 +20,6 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
# Keys that should stay in high precision (float32/bfloat16, not quantized)
|
||||
KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
|
||||
|
||||
@@ -39,8 +39,8 @@ def load_anima_dit(
|
||||
transformer_dtype: Optional[torch.dtype] = None,
|
||||
llm_adapter_path: Optional[str] = None,
|
||||
disable_mmap: bool = False,
|
||||
) -> anima_models.MiniTrainDIT:
|
||||
"""Load the MiniTrainDIT model from safetensors.
|
||||
) -> anima_models.Anima:
|
||||
"""Load the Anima model from safetensors.
|
||||
|
||||
Args:
|
||||
dit_path: Path to DiT safetensors file
|
||||
@@ -91,7 +91,7 @@ def load_anima_dit(
|
||||
)
|
||||
|
||||
# Build model normally on CPU — buffers get proper values from __init__
|
||||
dit = anima_models.MiniTrainDIT(**dit_config)
|
||||
dit = anima_models.Anima(**dit_config)
|
||||
|
||||
# Merge LLM adapter weights into state_dict if loaded separately
|
||||
if use_llm_adapter and llm_adapter_state_dict is not None:
|
||||
@@ -192,12 +192,13 @@ def load_anima_model(
|
||||
"split_attn": split_attn,
|
||||
}
|
||||
with init_empty_weights():
|
||||
model = anima_models.Anima(dit_config)
|
||||
model = anima_models.Anima(**dit_config)
|
||||
if dit_weight_dtype is not None:
|
||||
model.to(dit_weight_dtype)
|
||||
|
||||
# load model weights with dynamic fp8 optimization and LoRA merging if needed
|
||||
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
|
||||
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
|
||||
sd = load_safetensors_with_lora_and_fp8(
|
||||
model_files=dit_path,
|
||||
lora_weights_list=lora_weights_list,
|
||||
@@ -208,6 +209,7 @@ def load_anima_model(
|
||||
dit_weight_dtype=dit_weight_dtype,
|
||||
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
|
||||
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
|
||||
weight_transform_hooks=rename_hooks,
|
||||
)
|
||||
|
||||
if fp8_scaled:
|
||||
|
||||
@@ -9,7 +9,7 @@ import logging
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -220,6 +220,8 @@ def quantize_weight(
|
||||
tensor_max = torch.max(torch.abs(tensor).view(-1))
|
||||
scale = tensor_max / max_value
|
||||
|
||||
# print(f"Optimizing {key} with scale: {scale}")
|
||||
|
||||
# numerical safety
|
||||
scale = torch.clamp(scale, min=1e-8)
|
||||
scale = scale.to(torch.float32) # ensure scale is in float32 for division
|
||||
@@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
|
||||
weight_hook=None,
|
||||
quantization_mode: str = "block",
|
||||
block_size: Optional[int] = 64,
|
||||
disable_numpy_memmap: bool = False,
|
||||
weight_transform_hooks: Optional[WeightTransformHooks] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
|
||||
@@ -260,6 +264,8 @@ def load_safetensors_with_fp8_optimization(
|
||||
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
|
||||
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
|
||||
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
|
||||
disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
|
||||
weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
|
||||
|
||||
Returns:
|
||||
dict: FP8 optimized state dict
|
||||
@@ -288,7 +294,9 @@ def load_safetensors_with_fp8_optimization(
|
||||
# Process each file
|
||||
state_dict = {}
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file) as f:
|
||||
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
|
||||
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
|
||||
|
||||
keys = f.keys()
|
||||
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
|
||||
value = f.get_tensor(key)
|
||||
@@ -311,6 +319,11 @@ def load_safetensors_with_fp8_optimization(
|
||||
value = value.to(calc_device)
|
||||
|
||||
original_dtype = value.dtype
|
||||
if original_dtype.itemsize == 1:
|
||||
raise ValueError(
|
||||
f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
|
||||
+ f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
|
||||
)
|
||||
quantized_weight, scale_tensor = quantize_weight(
|
||||
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
|
||||
)
|
||||
@@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
|
||||
else:
|
||||
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
|
||||
|
||||
o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1)
|
||||
o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
|
||||
return o.to(input_dtype)
|
||||
|
||||
else:
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
from library.device_utils import synchronize_device
|
||||
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -44,7 +44,7 @@ def filter_lora_state_dict(
|
||||
|
||||
def load_safetensors_with_lora_and_fp8(
|
||||
model_files: Union[str, List[str]],
|
||||
lora_weights_list: Optional[Dict[str, torch.Tensor]],
|
||||
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
|
||||
lora_multipliers: Optional[List[float]],
|
||||
fp8_optimization: bool,
|
||||
calc_device: torch.device,
|
||||
@@ -52,19 +52,23 @@ def load_safetensors_with_lora_and_fp8(
|
||||
dit_weight_dtype: Optional[torch.dtype] = None,
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
disable_numpy_memmap: bool = False,
|
||||
weight_transform_hooks: Optional[WeightTransformHooks] = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
|
||||
|
||||
Args:
|
||||
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
|
||||
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
|
||||
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
|
||||
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
|
||||
fp8_optimization (bool): Whether to apply FP8 optimization.
|
||||
calc_device (torch.device): Device to calculate on.
|
||||
move_to_device (bool): Whether to move tensors to the calculation device after loading.
|
||||
target_keys (Optional[List[str]]): Keys to target for optimization.
|
||||
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
|
||||
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
|
||||
weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
|
||||
"""
|
||||
|
||||
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
|
||||
@@ -73,19 +77,9 @@ def load_safetensors_with_lora_and_fp8(
|
||||
|
||||
extended_model_files = []
|
||||
for model_file in model_files:
|
||||
basename = os.path.basename(model_file)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
state_dict = {}
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(model_file), filename)
|
||||
if os.path.exists(filepath):
|
||||
extended_model_files.append(filepath)
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
split_filenames = get_split_weight_filenames(model_file)
|
||||
if split_filenames is not None:
|
||||
extended_model_files.extend(split_filenames)
|
||||
else:
|
||||
extended_model_files.append(model_file)
|
||||
model_files = extended_model_files
|
||||
@@ -114,7 +108,7 @@ def load_safetensors_with_lora_and_fp8(
|
||||
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
|
||||
|
||||
# make hook for LoRA merging
|
||||
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
|
||||
def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
|
||||
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
|
||||
|
||||
if not model_weight_key.endswith(".weight"):
|
||||
@@ -145,6 +139,13 @@ def load_safetensors_with_lora_and_fp8(
|
||||
down_weight = down_weight.to(calc_device)
|
||||
up_weight = up_weight.to(calc_device)
|
||||
|
||||
original_dtype = model_weight.dtype
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
# temporarily convert to float16 for calculation
|
||||
model_weight = model_weight.to(torch.float16)
|
||||
down_weight = down_weight.to(torch.float16)
|
||||
up_weight = up_weight.to(torch.float16)
|
||||
|
||||
# W <- W + U * D
|
||||
if len(model_weight.size()) == 2:
|
||||
# linear
|
||||
@@ -166,6 +167,9 @@ def load_safetensors_with_lora_and_fp8(
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
model_weight = model_weight + multiplier * conved * scale
|
||||
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
model_weight = model_weight.to(original_dtype) # convert back to original dtype
|
||||
|
||||
# remove LoRA keys from set
|
||||
lora_weight_keys.remove(down_key)
|
||||
lora_weight_keys.remove(up_key)
|
||||
@@ -187,6 +191,8 @@ def load_safetensors_with_lora_and_fp8(
|
||||
target_keys,
|
||||
exclude_keys,
|
||||
weight_hook=weight_hook,
|
||||
disable_numpy_memmap=disable_numpy_memmap,
|
||||
weight_transform_hooks=weight_transform_hooks,
|
||||
)
|
||||
|
||||
for lora_weight_keys in list_of_lora_weight_keys:
|
||||
@@ -208,6 +214,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
weight_hook: callable = None,
|
||||
disable_numpy_memmap: bool = False,
|
||||
weight_transform_hooks: Optional[WeightTransformHooks] = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
|
||||
@@ -218,7 +226,14 @@ def load_safetensors_with_fp8_optimization_and_hook(
|
||||
)
|
||||
# dit_weight_dtype is not used because we use fp8 optimization
|
||||
state_dict = load_safetensors_with_fp8_optimization(
|
||||
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
|
||||
model_files,
|
||||
calc_device,
|
||||
target_keys,
|
||||
exclude_keys,
|
||||
move_to_device=move_to_device,
|
||||
weight_hook=weight_hook,
|
||||
disable_numpy_memmap=disable_numpy_memmap,
|
||||
weight_transform_hooks=weight_transform_hooks,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
@@ -226,7 +241,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
|
||||
)
|
||||
state_dict = {}
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file) as f:
|
||||
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
|
||||
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
|
||||
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
|
||||
if weight_hook is None and move_to_device:
|
||||
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
@@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
|
||||
validated[key] = value
|
||||
return validated
|
||||
|
||||
# print(f"Using memory efficient save file: {filename}")
|
||||
|
||||
header = {}
|
||||
offset = 0
|
||||
@@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
|
||||
by using memory mapping for large tensors and avoiding unnecessary copies.
|
||||
"""
|
||||
|
||||
def __init__(self, filename):
|
||||
def __init__(self, filename, disable_numpy_memmap=False):
|
||||
"""Initialize the SafeTensor reader.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the safetensors file to read.
|
||||
disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
|
||||
"""
|
||||
self.filename = filename
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
self.disable_numpy_memmap = disable_numpy_memmap
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter context manager."""
|
||||
@@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
|
||||
# Use memmap for large tensors to avoid intermediate copies.
|
||||
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
|
||||
# So we only use memmap if device is not cpu.
|
||||
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
|
||||
# If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
|
||||
if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
|
||||
# Create memory map for zero-copy reading
|
||||
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
|
||||
byte_tensor = torch.from_numpy(mm) # zero copy
|
||||
@@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:
|
||||
|
||||
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
|
||||
path: str,
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
disable_numpy_memmap: bool = False,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
@@ -293,7 +302,7 @@ def load_safetensors(
|
||||
# logger.info(f"Loading without mmap (experimental)")
|
||||
state_dict = {}
|
||||
device = torch.device(device) if device is not None else None
|
||||
with MemoryEfficientSafeOpen(path) as f:
|
||||
with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
|
||||
synchronize_device(device)
|
||||
@@ -309,6 +318,29 @@ def load_safetensors(
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
    """
    Resolve every shard path of a split safetensors checkpoint.

    A file whose name looks like ``<prefix>00001-of-00004.safetensors`` is treated as
    one shard of a split checkpoint. The full paths of all shards (1..count, in order)
    located in the same directory are returned. Shard names are rebuilt with
    five-digit zero padding.

    Returns None when the file name does not follow the split naming pattern.
    Raises FileNotFoundError when any expected shard is missing on disk.
    """
    directory = os.path.dirname(file_path)
    name = os.path.basename(file_path)

    shard_match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", name)
    if not shard_match:
        return None

    # Everything before the shard index is shared by all shards.
    stem = name[: shard_match.start(2)]
    total = int(shard_match.group(3))

    shard_paths = []
    for index in range(1, total + 1):
        shard_path = os.path.join(directory, f"{stem}{index:05d}-of-{total:05d}.safetensors")
        if not os.path.exists(shard_path):
            raise FileNotFoundError(f"File {shard_path} not found")
        shard_paths.append(shard_path)
    return shard_paths
|
||||
|
||||
|
||||
def load_split_weights(
|
||||
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
@@ -319,19 +351,11 @@ def load_split_weights(
|
||||
device = torch.device(device)
|
||||
|
||||
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
|
||||
basename = os.path.basename(file_path)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
split_filenames = get_split_weight_filenames(file_path)
|
||||
if split_filenames is not None:
|
||||
state_dict = {}
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(file_path), filename)
|
||||
if os.path.exists(filepath):
|
||||
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
for filename in split_filenames:
|
||||
state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
|
||||
else:
|
||||
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
return state_dict
|
||||
@@ -349,3 +373,107 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
|
||||
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
|
||||
return key
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
class WeightTransformHooks:
    """Bundle of optional callables used to transform weights while loading.

    Every hook defaults to None, meaning that transformation is not applied.
    """

    # split_hook(original_key, original_tensor) -> (new_keys, new_tensors).
    # Called with tensor=None to query the new keys only; the adapter treats a
    # None key list as "this key is not split".
    split_hook: Optional[callable] = None
    # concat_hook(original_key, tensors) -> (new_key, concatenated_tensor).
    # Called with tensors=None to query the new key only; a None key means
    # "this key is not concatenated".
    concat_hook: Optional[callable] = None
    # rename_hook(original_key) -> new_key, a one-to-one key rename.
    rename_hook: Optional[callable] = None
|
||||
|
||||
|
||||
class TensorWeightAdapter:
    """
    A wrapper that applies weight conversion hooks (split, concat, rename) on top of MemoryEfficientSafeOpen.

    The adapter exposes the same ``keys()`` / ``get_tensor()`` interface as the wrapped reader, but with the
    transformed key names, and applies the hooks lazily when a tensor is requested.

    split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
    concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
    rename_hook: A callable that takes (original_key: str) and returns (new_key: str).

    If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.
    A hook signals "not applicable to this key" by returning None as the key(s).

    Precedence per original key: split, then concat, then rename. The rename hook is not applied to keys
    produced by split or concat.

    No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
    Do not use this wrapper as a context manager directly, like `with TensorWeightAdapter(...) as f:`.

    **concat_hook is not tested yet.**
    """

    def __init__(self, weight_convert_hook: "WeightTransformHooks", original_f: "MemoryEfficientSafeOpen"):
        """Scan the keys of ``original_f`` once and build the new-key -> original-key mapping.

        Args:
            weight_convert_hook: Hooks to apply; any of them may be None.
            original_f: Opened reader providing ``keys()`` and ``get_tensor(key, device=, dtype=)``.
        """
        self.original_f = original_f
        # for split: new_key -> original_key; for concat: new_key -> list of original_keys;
        # for rename: new_key -> original_key. Keys without any hook applied are not stored.
        self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = {}
        self.concat_key_set = set()  # set of concatenated keys
        self.split_key_set = set()  # set of split keys
        self.new_keys = []
        self.tensor_cache = {}  # cache for split tensors, keyed by the NEW (split) key
        self.split_hook = weight_convert_hook.split_hook
        self.concat_hook = weight_convert_hook.concat_hook
        self.rename_hook = weight_convert_hook.rename_hook

        for key in self.original_f.keys():
            if self.split_hook is not None:
                converted_keys, _ = self.split_hook(key, None)  # get new keys only
                if converted_keys is not None:
                    for converted_key in converted_keys:
                        self.new_key_to_original_key_map[converted_key] = key
                        self.split_key_set.add(converted_key)
                    self.new_keys.extend(converted_keys)
                    continue  # skip concat_hook if split_hook is applied

            if self.concat_hook is not None:
                converted_key, _ = self.concat_hook(key, None)  # get new key only
                if converted_key is not None:
                    if converted_key not in self.concat_key_set:  # first time seeing this concatenated key
                        self.concat_key_set.add(converted_key)
                        self.new_key_to_original_key_map[converted_key] = []
                        # list the concatenated key once, no matter how many original keys feed it
                        self.new_keys.append(converted_key)

                    # multiple original keys map to the same concatenated key
                    self.new_key_to_original_key_map[converted_key].append(key)
                    continue  # skip to next key

            # direct mapping
            if self.rename_hook is not None:
                new_key = self.rename_hook(key)
                self.new_key_to_original_key_map[new_key] = key
            else:
                new_key = key

            self.new_keys.append(new_key)

    def keys(self) -> list[str]:
        """Return the transformed key names, in the order the original keys were scanned."""
        return self.new_keys

    def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """Load the tensor for ``new_key``, applying split or concat hooks as needed.

        Args:
            new_key: A key as returned by ``keys()``.
            device: Forwarded to the wrapped reader's ``get_tensor``.
            dtype: Forwarded to the wrapped reader's ``get_tensor``.

        Returns:
            The tensor for ``new_key`` (a split piece, a concatenation result, or the original tensor).
        """
        if new_key not in self.new_key_to_original_key_map:
            # direct mapping without rename
            return self.original_f.get_tensor(new_key, device=device, dtype=dtype)

        if new_key in self.split_key_set:
            # split hook: sibling split keys are requested one after another, so the pieces are cached.
            # The cache is keyed by the new keys, so membership must be tested with new_key;
            # otherwise the source tensor is reloaded and re-split for every sibling.
            if new_key not in self.tensor_cache:  # not yet split (or this piece was already consumed)
                original_key = self.new_key_to_original_key_map[new_key]
                original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
                new_keys, new_tensors = self.split_hook(original_key, original_tensor)  # apply split hook
                for k, t in zip(new_keys, new_tensors):
                    self.tensor_cache[k] = t
            return self.tensor_cache.pop(new_key)  # return and remove from cache

        if new_key in self.concat_key_set:
            # concat hook: concatenated key is requested only once, so we do not cache the result
            original_keys = self.new_key_to_original_key_map[new_key]
            tensors = {
                original_key: self.original_f.get_tensor(original_key, device=device, dtype=dtype)
                for original_key in original_keys
            }
            _, concatenated_tensors = self.concat_hook(original_keys[0], tensors)  # apply concat hook
            return concatenated_tensors

        # direct mapping with rename
        original_key = self.new_key_to_original_key_map[new_key]
        return self.original_f.get_tensor(original_key, device=device, dtype=dtype)
|
||||
|
||||
Reference in New Issue
Block a user