feat: block-wise fp8 quantization

This commit is contained in:
Kohya S
2025-09-18 21:20:54 +09:00
parent 2ce506e187
commit f6b4bdc83f
4 changed files with 186 additions and 102 deletions

View File

@@ -1,12 +1,8 @@
# copy from Musubi Tuner
import os
import re
from typing import Dict, List, Optional, Union
import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen
@@ -84,7 +80,7 @@ def load_safetensors_with_lora_and_fp8(
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors"
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(model_file), filename)
if os.path.exists(filepath):
extended_model_files.append(filepath)
@@ -118,7 +114,7 @@ def load_safetensors_with_lora_and_fp8(
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging
def weight_hook_func(model_weight_key, model_weight):
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"):
@@ -176,7 +172,8 @@ def load_safetensors_with_lora_and_fp8(
if alpha_key in lora_weight_keys:
lora_weight_keys.remove(alpha_key)
model_weight = model_weight.to(original_device) # move back to original device
if not keep_on_calc_device and original_device != calc_device:
model_weight = model_weight.to(original_device) # move back to original device
return model_weight
weight_hook = weight_hook_func
@@ -231,19 +228,18 @@ def load_safetensors_with_fp8_optimization_and_hook(
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
value = f.get_tensor(key)
if weight_hook is not None:
value = weight_hook(key, value)
if move_to_device:
if dit_weight_dtype is None:
value = value.to(calc_device, non_blocking=True)
else:
if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
else:
value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer
if weight_hook is not None:
value = weight_hook(key, value, keep_on_calc_device=move_to_device)
if move_to_device:
value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
elif dit_weight_dtype is not None:
value = value.to(dit_weight_dtype)
elif dit_weight_dtype is not None:
value = value.to(dit_weight_dtype)
state_dict[key] = value
if move_to_device:
synchronize_device(calc_device)