feat: support LoRA format without net. prefix

This commit is contained in:
Kohya S
2026-02-09 23:21:04 +09:00
parent 8f5b298906
commit 0f413974b7
7 changed files with 265 additions and 157 deletions

View File

@@ -212,7 +212,7 @@ def check_inputs(args: argparse.Namespace) -> Tuple[int, int]:
def load_dit_model(
args: argparse.Namespace, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None
) -> anima_models.MiniTrainDIT:
) -> anima_models.Anima:
"""load DiT model
Args:
@@ -221,7 +221,7 @@ def load_dit_model(
dit_weight_dtype: data type for the model weights. None for as-is
Returns:
anima_models.MiniTrainDIT: DiT model instance
anima_models.Anima: DiT model instance
"""
# If LyCORIS is enabled, we will load the model to CPU and then merge LoRA weights (static method)
@@ -439,7 +439,7 @@ def generate(
else:
# use shared model
logger.info("Using shared DiT model.")
anima: anima_models.MiniTrainDIT = shared_models["model"]
anima: anima_models.Anima = shared_models["model"]
if precomputed_text_data is not None:
logger.info("Using precomputed text data.")
@@ -455,7 +455,7 @@ def generate(
def generate_body(
args: Union[argparse.Namespace, SimpleNamespace],
anima: anima_models.MiniTrainDIT,
anima: anima_models.Anima,
context: Dict[str, Any],
context_null: Optional[Dict[str, Any]],
device: torch.device,
@@ -479,7 +479,7 @@ def generate_body(
negative_embed = context_null["embed"][0].to(device, dtype=torch.bfloat16)
# Prepare latent variables
num_channels_latents = 16 # anima_models.MiniTrainDIT.LATENT_CHANNELS
num_channels_latents = anima_models.Anima.LATENT_CHANNELS
shape = (
1,
num_channels_latents,

View File

@@ -1049,13 +1049,15 @@ class Block(nn.Module):
)
# Main DiT Model: MiniTrainDIT
class MiniTrainDIT(nn.Module):
# Main DiT Model: MiniTrainDIT (renamed to Anima)
class Anima(nn.Module):
"""Cosmos-Predict2 DiT model for image/video generation.
28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter.
"""
LATENT_CHANNELS = 16
def __init__(
self,
max_img_h: int,
@@ -1307,7 +1309,7 @@ class MiniTrainDIT(nn.Module):
return
self.offloader.prepare_block_devices_before_forward(self.blocks)
def forward(
def forward_mini_train_dit(
self,
x_B_C_T_H_W: torch.Tensor,
timesteps_B_T: torch.Tensor,
@@ -1378,6 +1380,41 @@ class MiniTrainDIT(nn.Module):
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
target_input_ids: Optional[torch.Tensor] = None,
target_attention_mask: Optional[torch.Tensor] = None,
source_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
return self.forward_mini_train_dit(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
def _preprocess_text_embeds(
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
):
if target_input_ids is not None:
# print(
# f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}"
# )
# print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
context = self.llm_adapter(
source_hidden_states,
target_input_ids,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
)
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
# print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
return context
else:
return source_hidden_states
# LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space
class LLMAdapterRMSNorm(nn.Module):
@@ -1531,33 +1568,15 @@ class LLMAdapterTransformerBlock(nn.Module):
)
x = x + attn_out
if source_attention_mask is not None:
# Select batch elements where source_attention_mask has at least one True value
batch_indices = torch.where(source_attention_mask.any(dim=(1, 2, 3)))[0]
# print("Batch indices for cross-attention:", batch_indices)
if len(batch_indices) == 0:
pass # No valid batch elements, skip cross-attention
else:
normed = self.norm_cross_attn(x[batch_indices])
attn_out = self.cross_attn(
normed,
mask=None,
context=context[batch_indices],
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x[batch_indices] = x[batch_indices] + attn_out
else:
# Standard cross-attention without masking
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(
normed,
mask=source_attention_mask,
context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(
normed,
mask=source_attention_mask,
context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
return x
@@ -1623,76 +1642,6 @@ class LLMAdapter(nn.Module):
return self.norm(self.out_proj(x))
class Anima(nn.Module):
"""
Wrapper class for the MiniTrainDIT and LLM Adapter.
"""
LATENT_CHANNELS = 16
def __init__(self, dit_config: dict):
super().__init__()
self.net = MiniTrainDIT(**dit_config)
@property
def device(self):
return self.net.device
@property
def dtype(self):
return self.net.dtype
def enable_gradient_checkpointing(self, *args, **kwargs):
self.net.enable_gradient_checkpointing(*args, **kwargs)
def disable_gradient_checkpointing(self):
self.net.disable_gradient_checkpointing()
def enable_block_swap(self, *args, **kwargs):
self.net.enable_block_swap(*args, **kwargs)
def move_to_device_except_swap_blocks(self, *args, **kwargs):
self.net.move_to_device_except_swap_blocks(*args, **kwargs)
def prepare_block_swap_before_forward(self, *args, **kwargs):
self.net.prepare_block_swap_before_forward(*args, **kwargs)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
target_input_ids: Optional[torch.Tensor] = None,
target_attention_mask: Optional[torch.Tensor] = None,
source_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
def _preprocess_text_embeds(
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
):
if target_input_ids is not None:
# print(
# f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}"
# )
# print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
context = self.net.llm_adapter(
source_hidden_states,
target_input_ids,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
)
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
# print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
return context
else:
return source_hidden_states
# VAE Wrapper
# VAE normalization constants

View File

@@ -225,7 +225,7 @@ def get_anima_param_groups(
"""Create parameter groups for Anima training with separate learning rates.
Args:
dit: MiniTrainDIT model
dit: Anima model
base_lr: Base learning rate
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
cross_attn_lr: LR for cross-attention layers
@@ -307,7 +307,7 @@ def save_anima_model_on_train_end(
save_dtype: torch.dtype,
epoch: int,
global_step: int,
dit: anima_models.MiniTrainDIT,
dit: anima_models.Anima,
):
"""Save Anima model at the end of training."""
@@ -328,7 +328,7 @@ def save_anima_model_on_epoch_end_or_stepwise(
epoch: int,
num_train_epochs: int,
global_step: int,
dit: anima_models.MiniTrainDIT,
dit: anima_models.Anima,
):
"""Save Anima model at epoch end or specific steps."""
@@ -356,7 +356,7 @@ def do_sample(
height: int,
width: int,
seed: Optional[int],
dit: anima_models.MiniTrainDIT,
dit: anima_models.Anima,
crossattn_emb: torch.Tensor,
steps: int,
dtype: torch.dtype,
@@ -370,7 +370,7 @@ def do_sample(
Args:
height, width: Output image dimensions
seed: Random seed (None for random)
dit: MiniTrainDIT model
dit: Anima model
crossattn_emb: Cross-attention embeddings (B, N, D)
steps: Number of sampling steps
dtype: Compute dtype
@@ -588,8 +588,8 @@ def _sample_image_inference(
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available
if dit.net.use_llm_adapter:
crossattn_emb = dit.net.llm_adapter(
if dit.use_llm_adapter:
crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids,
target_attention_mask=t5_attn_mask,
@@ -616,8 +616,8 @@ def _sample_image_inference(
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device)
if dit.net.use_llm_adapter:
neg_crossattn_emb = dit.net.llm_adapter(
if dit.use_llm_adapter:
neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids,
target_attention_mask=neg_t5_am,

View File

@@ -11,6 +11,7 @@ from accelerate import init_empty_weights
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library import anima_models
from library.safetensors_utils import WeightTransformHooks
from .utils import setup_logging
setup_logging()
@@ -19,7 +20,6 @@ import logging
logger = logging.getLogger(__name__)
# Keys that should stay in high precision (float32/bfloat16, not quantized)
KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
@@ -39,8 +39,8 @@ def load_anima_dit(
transformer_dtype: Optional[torch.dtype] = None,
llm_adapter_path: Optional[str] = None,
disable_mmap: bool = False,
) -> anima_models.MiniTrainDIT:
"""Load the MiniTrainDIT model from safetensors.
) -> anima_models.Anima:
"""Load the Anima model from safetensors.
Args:
dit_path: Path to DiT safetensors file
@@ -91,7 +91,7 @@ def load_anima_dit(
)
# Build model normally on CPU — buffers get proper values from __init__
dit = anima_models.MiniTrainDIT(**dit_config)
dit = anima_models.Anima(**dit_config)
# Merge LLM adapter weights into state_dict if loaded separately
if use_llm_adapter and llm_adapter_state_dict is not None:
@@ -192,12 +192,13 @@ def load_anima_model(
"split_attn": split_attn,
}
with init_empty_weights():
model = anima_models.Anima(dit_config)
model = anima_models.Anima(**dit_config)
if dit_weight_dtype is not None:
model.to(dit_weight_dtype)
# load model weights with dynamic fp8 optimization and LoRA merging if needed
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
sd = load_safetensors_with_lora_and_fp8(
model_files=dit_path,
lora_weights_list=lora_weights_list,
@@ -208,6 +209,7 @@ def load_anima_model(
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
weight_transform_hooks=rename_hooks,
)
if fp8_scaled:

View File

@@ -9,7 +9,7 @@ import logging
from tqdm import tqdm
from library.device_utils import clean_memory_on_device
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
from library.utils import setup_logging
setup_logging()
@@ -220,6 +220,8 @@ def quantize_weight(
tensor_max = torch.max(torch.abs(tensor).view(-1))
scale = tensor_max / max_value
# print(f"Optimizing {key} with scale: {scale}")
# numerical safety
scale = torch.clamp(scale, min=1e-8)
scale = scale.to(torch.float32) # ensure scale is in float32 for division
@@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook=None,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict:
"""
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
@@ -260,6 +264,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
Returns:
dict: FP8 optimized state dict
@@ -288,7 +294,9 @@ def load_safetensors_with_fp8_optimization(
# Process each file
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
keys = f.keys()
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
value = f.get_tensor(key)
@@ -311,6 +319,11 @@ def load_safetensors_with_fp8_optimization(
value = value.to(calc_device)
original_dtype = value.dtype
if original_dtype.itemsize == 1:
raise ValueError(
f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
+ f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
)
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)
@@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
else:
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1)
o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
return o.to(input_dtype)
else:

View File

@@ -5,7 +5,7 @@ import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from library.utils import setup_logging
setup_logging()
@@ -44,7 +44,7 @@ def filter_lora_state_dict(
def load_safetensors_with_lora_and_fp8(
model_files: Union[str, List[str]],
lora_weights_list: Optional[Dict[str, torch.Tensor]],
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
lora_multipliers: Optional[List[float]],
fp8_optimization: bool,
calc_device: torch.device,
@@ -52,19 +52,23 @@ def load_safetensors_with_lora_and_fp8(
dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
Args:
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
fp8_optimization (bool): Whether to apply FP8 optimization.
calc_device (torch.device): Device to calculate on.
move_to_device (bool): Whether to move tensors to the calculation device after loading.
target_keys (Optional[List[str]]): Keys to target for optimization.
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
"""
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
@@ -73,19 +77,9 @@ def load_safetensors_with_lora_and_fp8(
extended_model_files = []
for model_file in model_files:
basename = os.path.basename(model_file)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(model_file), filename)
if os.path.exists(filepath):
extended_model_files.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
split_filenames = get_split_weight_filenames(model_file)
if split_filenames is not None:
extended_model_files.extend(split_filenames)
else:
extended_model_files.append(model_file)
model_files = extended_model_files
@@ -114,7 +108,7 @@ def load_safetensors_with_lora_and_fp8(
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"):
@@ -145,6 +139,13 @@ def load_safetensors_with_lora_and_fp8(
down_weight = down_weight.to(calc_device)
up_weight = up_weight.to(calc_device)
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
# temporarily convert to float16 for calculation
model_weight = model_weight.to(torch.float16)
down_weight = down_weight.to(torch.float16)
up_weight = up_weight.to(torch.float16)
# W <- W + U * D
if len(model_weight.size()) == 2:
# linear
@@ -166,6 +167,9 @@ def load_safetensors_with_lora_and_fp8(
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
model_weight = model_weight + multiplier * conved * scale
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(original_dtype) # convert back to original dtype
# remove LoRA keys from set
lora_weight_keys.remove(down_key)
lora_weight_keys.remove(up_key)
@@ -187,6 +191,8 @@ def load_safetensors_with_lora_and_fp8(
target_keys,
exclude_keys,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
for lora_weight_keys in list_of_lora_weight_keys:
@@ -208,6 +214,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
weight_hook: callable = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
@@ -218,7 +226,14 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
# dit_weight_dtype is not used because we use fp8 optimization
state_dict = load_safetensors_with_fp8_optimization(
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
model_files,
calc_device,
target_keys,
exclude_keys,
move_to_device=move_to_device,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
else:
logger.info(
@@ -226,7 +241,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
import os
import re
import numpy as np
@@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
validated[key] = value
return validated
# print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
@@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
by using memory mapping for large tensors and avoiding unnecessary copies.
"""
def __init__(self, filename):
def __init__(self, filename, disable_numpy_memmap=False):
"""Initialize the SafeTensor reader.
Args:
filename (str): Path to the safetensors file to read.
disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
"""
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
self.disable_numpy_memmap = disable_numpy_memmap
def __enter__(self):
"""Enter context manager."""
@@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
# Use memmap for large tensors to avoid intermediate copies.
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
# So we only use memmap if device is not cpu.
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# Create memory map for zero-copy reading
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
byte_tensor = torch.from_numpy(mm) # zero copy
@@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
path: str,
device: Union[str, torch.device],
disable_mmap: bool = False,
dtype: Optional[torch.dtype] = None,
disable_numpy_memmap: bool = False,
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
@@ -293,7 +302,7 @@ def load_safetensors(
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
device = torch.device(device) if device is not None else None
with MemoryEfficientSafeOpen(path) as f:
with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
synchronize_device(device)
@@ -309,6 +318,29 @@ def load_safetensors(
return state_dict
def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
"""
Get the list of split weight filenames (full paths) if the file name ends with 00001-of-00004 etc.
Returns None if the file is not split.
"""
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
filenames = []
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
filenames.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
return filenames
else:
return None
def load_split_weights(
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
@@ -319,19 +351,11 @@ def load_split_weights(
device = torch.device(device)
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
split_filenames = get_split_weight_filenames(file_path)
if split_filenames is not None:
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
raise FileNotFoundError(f"File {filepath} not found")
for filename in split_filenames:
state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
return state_dict
@@ -349,3 +373,107 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
return key
return None
@dataclass
class WeightTransformHooks:
split_hook: Optional[callable] = None
concat_hook: Optional[callable] = None
rename_hook: Optional[callable] = None
class TensorWeightAdapter:
"""
A wrapper for weight conversion hooks (split and concat) to be used with MemoryEfficientSafeOpen.
This wrapper adapts the original MemoryEfficientSafeOpen to apply the provided split and concat hooks
when loading tensors.
split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
rename_hook: A callable that takes (original_key: str) and returns (new_key: str).
If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.
No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
Do not use this wrapper as a context manager directly, like `with WeightConvertHookWrapper(...) as f:`.
**concat_hook is not tested yet.**
"""
def __init__(self, weight_convert_hook: WeightTransformHooks, original_f: MemoryEfficientSafeOpen):
self.original_f = original_f
self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = (
{}
) # for split: new_key -> original_key; for concat: new_key -> list of original_keys; for direct mapping: new_key -> original_key
self.concat_key_set = set() # set of concatenated keys
self.split_key_set = set() # set of split keys
self.new_keys = []
self.tensor_cache = {} # cache for split tensors
self.split_hook = weight_convert_hook.split_hook
self.concat_hook = weight_convert_hook.concat_hook
self.rename_hook = weight_convert_hook.rename_hook
for key in self.original_f.keys():
if self.split_hook is not None:
converted_keys, _ = self.split_hook(key, None) # get new keys only
if converted_keys is not None:
for converted_key in converted_keys:
self.new_key_to_original_key_map[converted_key] = key
self.split_key_set.add(converted_key)
self.new_keys.extend(converted_keys)
continue # skip concat_hook if split_hook is applied
if self.concat_hook is not None:
converted_key, _ = self.concat_hook(key, None) # get new key only
if converted_key is not None:
if converted_key not in self.concat_key_set: # first time seeing this concatenated key
self.concat_key_set.add(converted_key)
self.new_key_to_original_key_map[converted_key] = []
# multiple original keys map to the same concatenated key
self.new_key_to_original_key_map[converted_key].append(key)
self.new_keys.append(converted_key)
continue # skip to next key
# direct mapping
if self.rename_hook is not None:
new_key = self.rename_hook(key)
self.new_key_to_original_key_map[new_key] = key
else:
new_key = key
self.new_keys.append(new_key)
def keys(self) -> list[str]:
return self.new_keys
def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# load tensor by new_key, applying split or concat hooks as needed
if new_key not in self.new_key_to_original_key_map:
# direct mapping
return self.original_f.get_tensor(new_key, device=device, dtype=dtype)
elif new_key in self.split_key_set:
# split hook: split key is requested multiple times, so we cache the result
original_key = self.new_key_to_original_key_map[new_key]
if original_key not in self.tensor_cache: # not yet split
original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
new_keys, new_tensors = self.split_hook(original_key, original_tensor) # apply split hook
for k, t in zip(new_keys, new_tensors):
self.tensor_cache[k] = t
return self.tensor_cache.pop(new_key) # return and remove from cache
elif new_key in self.concat_key_set:
# concat hook: concatenated key is requested only once, so we do not cache the result
tensors = {}
for original_key in self.new_key_to_original_key_map[new_key]:
tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
tensors[original_key] = tensor
_, concatenated_tensors = self.concat_hook(self.new_key_to_original_key_map[new_key][0], tensors) # apply concat hook
return concatenated_tensors
else:
# direct mapping
original_key = self.new_key_to_original_key_map[new_key]
return self.original_f.get_tensor(original_key, device=device, dtype=dtype)