feat: Update direct loading fp8 ckpt for LoRA training

This commit is contained in:
Kohya S
2024-08-27 21:40:02 +09:00
parent 0087a46e14
commit 3be712e3e0
6 changed files with 151 additions and 55 deletions

View File

@@ -9,13 +9,18 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
Aug 27, 2024 (update 2):
In FLUX.1 LoRA training, when `--fp8_base` is specified, the FLUX.1 model file with fp8 (`float8_e4m3fn` type) can be loaded directly. Also, in `flux_minimal_inference.py`, it is possible to load it by specifying `fp8 (float8_e4m3fn)` in `--flux_dtype`.
In `flux_merge_lora.py`, you can now specify the precision at save time with `fp8` (see `--help` for details). Also, if you do not specify the merge model, only the model type conversion will be performed.
Aug 27, 2024:
- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI.
- `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution.
- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option).
- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled.
Aug 25, 2024:
Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`.

View File

@@ -10,7 +10,6 @@ import einops
import numpy as np
import torch
from safetensors.torch import safe_open, load_file
from tqdm import tqdm
from PIL import Image
import accelerate
@@ -21,7 +20,7 @@ from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from library.utils import setup_logging
from library.utils import setup_logging, str_to_dtype
setup_logging()
import logging
@@ -288,28 +287,6 @@ if __name__ == "__main__":
name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way
is_schnell = name == "schnell"
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
if s is None:
return default_dtype
if s in ["bf16", "bfloat16"]:
return torch.bfloat16
elif s in ["fp16", "float16"]:
return torch.float16
elif s in ["fp32", "float32"]:
return torch.float32
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
return torch.float8_e4m3fn
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
return torch.float8_e4m3fnuz
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
return torch.float8_e5m2
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
return torch.float8_e5m2fnuz
elif s in ["fp8", "float8"]:
return torch.float8_e4m3fn # default fp8
else:
raise ValueError(f"Unsupported dtype: {s}")
def is_fp8(dt):
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
@@ -348,7 +325,7 @@ if __name__ == "__main__":
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# DiT
model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device)
model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype

View File

@@ -29,6 +29,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
@@ -61,9 +64,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
name = self.get_flux_model_name(args)
# if we load to cpu, flux.to(fp8) takes a long time
if args.fp8_base:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
model = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
)
if args.fp8_base:
# check dtype of model
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
elif model.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 FLUX model")
if args.split_mode:
model = self.prepare_split_model(model, weight_dtype, accelerator)

View File

@@ -1,5 +1,5 @@
import json
from typing import Union
from typing import Optional, Union
import einops
import torch
@@ -20,7 +20,9 @@ MODEL_VERSION_FLUX_V1 = "flux1"
# temporary copy from sd3_utils TODO refactor
def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
):
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
@@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap:
def load_flow_model(
name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.Flux:
logger.info(f"Building Flux model {name}")
with torch.device("meta"):
model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
model = flux_models.Flux(flux_models.configs[name].params)
if dtype is not None:
model = model.to(dtype)
# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")

View File

@@ -82,6 +82,66 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init)
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
Convert a string to a torch.dtype
Args:
s: string representation of the dtype
default_dtype: default dtype to return if s is None
Returns:
torch.dtype: the corresponding torch.dtype
Raises:
ValueError: if the dtype is not supported
Examples:
>>> str_to_dtype("float32")
torch.float32
>>> str_to_dtype("fp32")
torch.float32
>>> str_to_dtype("float16")
torch.float16
>>> str_to_dtype("fp16")
torch.float16
>>> str_to_dtype("bfloat16")
torch.bfloat16
>>> str_to_dtype("bf16")
torch.bfloat16
>>> str_to_dtype("fp8")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fn")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fnuz")
torch.float8_e4m3fnuz
>>> str_to_dtype("fp8_e5m2")
torch.float8_e5m2
>>> str_to_dtype("fp8_e5m2fnuz")
torch.float8_e5m2fnuz
"""
if s is None:
return default_dtype
if s in ["bf16", "bfloat16"]:
return torch.bfloat16
elif s in ["fp16", "float16"]:
return torch.float16
elif s in ["fp32", "float32", "float"]:
return torch.float32
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
return torch.float8_e4m3fn
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
return torch.float8_e4m3fnuz
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
return torch.float8_e5m2
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
return torch.float8_e5m2fnuz
elif s in ["fp8", "float8"]:
return torch.float8_e4m3fn # default fp8
else:
raise ValueError(f"Unsupported dtype: {s}")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
@@ -198,7 +258,7 @@ class MemoryEfficientSafeOpen:
if tensor_bytes is None:
byte_tensor = torch.empty(0, dtype=torch.uint8)
else:
tensor_bytes = bytearray(tensor_bytes) # make it writable
tensor_bytes = bytearray(tensor_bytes) # make it writable
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
# process float8 types

View File

@@ -8,7 +8,7 @@ from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library.utils import setup_logging
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging()
import logging
@@ -34,18 +34,23 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, state_dict, dtype, metadata):
def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False):
if dtype is not None:
logger.info(f"converting to {dtype}...")
for key in list(state_dict.keys()):
for key in tqdm(list(state_dict.keys())):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
logger.info(f"saving to: {file_name}")
save_file(state_dict, file_name, metadata=metadata)
if mem_eff_save:
mem_eff_save_file(state_dict, file_name, metadata=metadata)
else:
save_file(state_dict, file_name, metadata=metadata)
def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype):
def merge_to_flux_model(
loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
):
# create module map without loading state_dict
logger.info(f"loading keys from FLUX.1 model: {flux_model}")
lora_name_to_module_key = {}
@@ -57,7 +62,14 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
lora_name_to_module_key[lora_name] = key
flux_state_dict = load_file(flux_model, device=loading_device)
if mem_eff_load_save:
flux_state_dict = {}
with MemoryEfficientSafeOpen(flux_model) as flux_file:
for key in tqdm(flux_file.keys()):
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
else:
flux_state_dict = load_file(flux_model, device=loading_device)
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU
@@ -120,9 +132,17 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati
return flux_state_dict
def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype):
def merge_to_flux_model_diffusers(
loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
):
logger.info(f"loading keys from FLUX.1 model: {flux_model}")
flux_state_dict = load_file(flux_model, device=loading_device)
if mem_eff_load_save:
flux_state_dict = {}
with MemoryEfficientSafeOpen(flux_model) as flux_file:
for key in tqdm(flux_file.keys()):
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
else:
flux_state_dict = load_file(flux_model, device=loading_device)
def create_key_map(n_double_layers, n_single_layers):
key_map = {}
@@ -474,19 +494,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
def merge(args):
if args.models is None:
args.models = []
if args.ratios is None:
args.ratios = []
assert len(args.models) == len(
args.ratios
), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
@@ -500,11 +516,25 @@ def merge(args):
if args.flux_model is not None:
if not args.diffusers:
state_dict = merge_to_flux_model(
args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype
args.loading_device,
args.working_device,
args.flux_model,
args.models,
args.ratios,
merge_dtype,
save_dtype,
args.mem_eff_load_save,
)
else:
state_dict = merge_to_flux_model_diffusers(
args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype
args.loading_device,
args.working_device,
args.flux_model,
args.models,
args.ratios,
merge_dtype,
save_dtype,
args.mem_eff_load_save,
)
if args.no_metadata:
@@ -517,7 +547,7 @@ def merge(args):
)
logger.info(f"saving FLUX model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, sai_metadata)
save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
@@ -546,14 +576,14 @@ def setup_parser() -> argparse.ArgumentParser:
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
help="precision in saving, same to merging if omitted. supported types: "
"float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz"
" / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
)
parser.add_argument(
"--precision",
type=str,
default="float",
choices=["float", "fp16", "bf16"],
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
@@ -562,6 +592,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする",
)
parser.add_argument(
"--mem_eff_load_save",
action="store_true",
help="use custom memory efficient load and save functions for FLUX.1 model"
" / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する",
)
parser.add_argument(
"--loading_device",
type=str,