feat: Update direct loading fp8 ckpt for LoRA training

Kohya S
2024-08-27 21:40:02 +09:00
parent 0087a46e14
commit 3be712e3e0
6 changed files with 151 additions and 55 deletions


@@ -8,7 +8,7 @@ from safetensors import safe_open
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm

-from library.utils import setup_logging
+from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file

 setup_logging()
 import logging
@@ -34,18 +34,23 @@ def load_state_dict(file_name, dtype):
     return sd, metadata


-def save_to_file(file_name, state_dict, dtype, metadata):
+def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False):
     if dtype is not None:
         logger.info(f"converting to {dtype}...")
-        for key in list(state_dict.keys()):
+        for key in tqdm(list(state_dict.keys())):
             if type(state_dict[key]) == torch.Tensor:
                 state_dict[key] = state_dict[key].to(dtype)

     logger.info(f"saving to: {file_name}")
-    save_file(state_dict, file_name, metadata=metadata)
+    if mem_eff_save:
+        mem_eff_save_file(state_dict, file_name, metadata=metadata)
+    else:
+        save_file(state_dict, file_name, metadata=metadata)


-def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype):
+def merge_to_flux_model(
+    loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
+):
     # create module map without loading state_dict
     logger.info(f"loading keys from FLUX.1 model: {flux_model}")
     lora_name_to_module_key = {}
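Reviewer note: mem_eff_save_file and MemoryEfficientSafeOpen come from library/utils.py and are not shown in this diff. For orientation, here is a minimal sketch of what a memory-efficient safetensors save can look like — byte offsets are laid out first, then each tensor's raw bytes are streamed to disk one at a time instead of materializing the whole serialized file in memory. Only the safetensors header layout is authoritative here; every name is an assumption, not the actual library code, and the fp8 dtype needs a recent PyTorch:

import json
import struct

import torch

# safetensors dtype tags for the dtypes used in this sketch (subset only)
DTYPE_TAGS = {torch.float32: "F32", torch.float16: "F16", torch.bfloat16: "BF16", torch.float8_e4m3fn: "F8_E4M3"}


def sketch_mem_eff_save(state_dict, file_name, metadata=None):
    # pass 1: compute byte offsets for the JSON header without copying any tensor data
    header = {}
    offset = 0
    for key, t in state_dict.items():
        size = t.numel() * t.element_size()
        header[key] = {"dtype": DTYPE_TAGS[t.dtype], "shape": list(t.shape), "data_offsets": [offset, offset + size]}
        offset += size
    if metadata:
        header["__metadata__"] = metadata
    header_bytes = json.dumps(header).encode("utf-8")

    with open(file_name, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # 8-byte little-endian header length
        f.write(header_bytes)
        # pass 2: stream tensors; only one tensor is staged on the CPU at any moment
        for t in state_dict.values():
            data = t.detach().cpu().contiguous().view(torch.uint8)  # uint8 view also works for bf16/fp8
            f.write(data.numpy().tobytes())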
@@ -57,7 +62,14 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati
                 lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
                 lora_name_to_module_key[lora_name] = key

-    flux_state_dict = load_file(flux_model, device=loading_device)
+    if mem_eff_load_save:
+        flux_state_dict = {}
+        with MemoryEfficientSafeOpen(flux_model) as flux_file:
+            for key in tqdm(flux_file.keys()):
+                flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device)  # dtype is not changed
+    else:
+        flux_state_dict = load_file(flux_model, device=loading_device)
+
     for model, ratio in zip(models, ratios):
         logger.info(f"loading: {model}")
         lora_sd, _ = load_state_dict(model, merge_dtype)  # loading on CPU
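The reading side mirrors the save sketch above. MemoryEfficientSafeOpen itself is defined in library/utils.py and is not part of this diff; the following is a hypothetical stand-in showing the idea — parse the header once, then pull one tensor's bytes per get_tensor call with plain file reads instead of loading or mmapping the entire checkpoint:

import json
import struct

import torch

TORCH_DTYPES = {"F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16, "F8_E4M3": torch.float8_e4m3fn}


class SketchSafeOpen:  # hypothetical stand-in, not the real MemoryEfficientSafeOpen
    def __init__(self, file_name):
        self.f = open(file_name, "rb")
        header_len = struct.unpack("<Q", self.f.read(8))[0]
        self.header = json.loads(self.f.read(header_len))
        self.data_start = 8 + header_len

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.f.close()

    def keys(self):
        return [k for k in self.header if k != "__metadata__"]

    def get_tensor(self, key):
        begin, end = self.header[key]["data_offsets"]
        self.f.seek(self.data_start + begin)
        buf = bytearray(self.f.read(end - begin))  # bytes for this one tensor only
        t = torch.frombuffer(buf, dtype=torch.uint8)
        return t.view(TORCH_DTYPES[self.header[key]["dtype"]]).reshape(self.header[key]["shape"])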
@@ -120,9 +132,17 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati
     return flux_state_dict


-def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype):
+def merge_to_flux_model_diffusers(
+    loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
+):
     logger.info(f"loading keys from FLUX.1 model: {flux_model}")
-    flux_state_dict = load_file(flux_model, device=loading_device)
+    if mem_eff_load_save:
+        flux_state_dict = {}
+        with MemoryEfficientSafeOpen(flux_model) as flux_file:
+            for key in tqdm(flux_file.keys()):
+                flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device)  # dtype is not changed
+    else:
+        flux_state_dict = load_file(flux_model, device=loading_device)

     def create_key_map(n_double_layers, n_single_layers):
         key_map = {}
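The Diffusers code path gets the identical streaming block. The design point in both places: load_file pulls the entire checkpoint in at once, while the key-by-key loop keeps roughly one tensor in flight above the state dict being accumulated, and — per the "dtype is not changed" comment — each tensor arrives in its on-disk dtype. That is what lets an fp8 checkpoint be loaded directly, without first being widened to fp16/bf16, which is what the commit title refers to.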
@@ -474,19 +494,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
 def merge(args):
+    if args.models is None:
+        args.models = []
+    if args.ratios is None:
+        args.ratios = []
+
     assert len(args.models) == len(
         args.ratios
     ), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"

-    def str_to_dtype(p):
-        if p == "float":
-            return torch.float
-        if p == "fp16":
-            return torch.float16
-        if p == "bf16":
-            return torch.bfloat16
-        return None
-
     merge_dtype = str_to_dtype(args.precision)
     save_dtype = str_to_dtype(args.save_precision)
     if save_dtype is None:
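The local str_to_dtype helper is deleted in favor of the shared one now imported from library.utils (see the import hunk above), which also understands the fp8 type names listed in the new --save_precision help text below. Its real implementation is not in this diff; the following sketch is merely consistent with that help text, with all names assumed:

import torch

# hypothetical sketch of a str_to_dtype that accepts the documented names (fp8 dtypes need a recent PyTorch)
def sketch_str_to_dtype(s, default=None):
    if s is None:
        return default
    return {
        "float": torch.float32,
        "float32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp8": torch.float8_e4m3fn,  # "fp8" is documented as an alias for fp8_e4m3fn
        "fp8_e4m3fn": torch.float8_e4m3fn,
        "fp8_e4m3fnuz": torch.float8_e4m3fnuz,
        "fp8_e5m2": torch.float8_e5m2,
        "fp8_e5m2fnuz": torch.float8_e5m2fnuz,
    }[s]

Returning None when the argument is omitted is what makes the save_dtype fallback to merge_dtype in the lines above work.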
@@ -500,11 +516,25 @@ def merge(args):
     if args.flux_model is not None:
         if not args.diffusers:
             state_dict = merge_to_flux_model(
-                args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype
+                args.loading_device,
+                args.working_device,
+                args.flux_model,
+                args.models,
+                args.ratios,
+                merge_dtype,
+                save_dtype,
+                args.mem_eff_load_save,
             )
         else:
             state_dict = merge_to_flux_model_diffusers(
-                args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype
+                args.loading_device,
+                args.working_device,
+                args.flux_model,
+                args.models,
+                args.ratios,
+                merge_dtype,
+                save_dtype,
+                args.mem_eff_load_save,
             )

         if args.no_metadata:
@@ -517,7 +547,7 @@ def merge(args):
             )

         logger.info(f"saving FLUX model to: {args.save_to}")
-        save_to_file(args.save_to, state_dict, save_dtype, sai_metadata)
+        save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
     else:
         state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
@@ -546,14 +576,14 @@ def setup_parser() -> argparse.ArgumentParser:
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
help="precision in saving, same to merging if omitted. supported types: "
"float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz"
" / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
)
parser.add_argument(
"--precision",
type=str,
default="float",
choices=["float", "fp16", "bf16"],
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
@@ -562,6 +592,12 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする",
     )
+    parser.add_argument(
+        "--mem_eff_load_save",
+        action="store_true",
+        help="use custom memory efficient load and save functions for FLUX.1 model"
+        " / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する",
+    )
     parser.add_argument(
         "--loading_device",
         type=str,
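Taken together, a hypothetical invocation exercising the new flag might look like this (the script path and file names are illustrative, not part of this diff; every option used does appear in the hunks above):

python networks/flux_merge_lora.py \
    --flux_model flux1-dev-fp8.safetensors \
    --models my_lora.safetensors --ratios 1.0 \
    --loading_device cpu --working_device cuda \
    --save_precision fp8_e4m3fn \
    --save_to flux1-dev-fp8-merged.safetensors \
    --mem_eff_load_save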