mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
feat: block-wise fp8 quantization
This commit is contained in:
@@ -1,12 +1,8 @@
|
||||
# copy from Musubi Tuner
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import synchronize_device
|
||||
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
@@ -84,7 +80,7 @@ def load_safetensors_with_lora_and_fp8(
|
||||
count = int(match.group(3))
|
||||
state_dict = {}
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors"
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(model_file), filename)
|
||||
if os.path.exists(filepath):
|
||||
extended_model_files.append(filepath)
|
||||
@@ -118,7 +114,7 @@ def load_safetensors_with_lora_and_fp8(
|
||||
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
|
||||
|
||||
# make hook for LoRA merging
|
||||
def weight_hook_func(model_weight_key, model_weight):
|
||||
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
|
||||
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
|
||||
|
||||
if not model_weight_key.endswith(".weight"):
|
||||
@@ -176,7 +172,8 @@ def load_safetensors_with_lora_and_fp8(
|
||||
if alpha_key in lora_weight_keys:
|
||||
lora_weight_keys.remove(alpha_key)
|
||||
|
||||
model_weight = model_weight.to(original_device) # move back to original device
|
||||
if not keep_on_calc_device and original_device != calc_device:
|
||||
model_weight = model_weight.to(original_device) # move back to original device
|
||||
return model_weight
|
||||
|
||||
weight_hook = weight_hook_func
|
||||
@@ -231,19 +228,18 @@ def load_safetensors_with_fp8_optimization_and_hook(
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file) as f:
|
||||
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
|
||||
value = f.get_tensor(key)
|
||||
if weight_hook is not None:
|
||||
value = weight_hook(key, value)
|
||||
if move_to_device:
|
||||
if dit_weight_dtype is None:
|
||||
value = value.to(calc_device, non_blocking=True)
|
||||
else:
|
||||
if weight_hook is None and move_to_device:
|
||||
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
|
||||
else:
|
||||
value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer
|
||||
if weight_hook is not None:
|
||||
value = weight_hook(key, value, keep_on_calc_device=move_to_device)
|
||||
if move_to_device:
|
||||
value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
|
||||
elif dit_weight_dtype is not None:
|
||||
value = value.to(dit_weight_dtype)
|
||||
elif dit_weight_dtype is not None:
|
||||
value = value.to(dit_weight_dtype)
|
||||
|
||||
state_dict[key] = value
|
||||
|
||||
if move_to_device:
|
||||
synchronize_device(calc_device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user