mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
feat: Update direct loading fp8 ckpt for LoRA training
This commit is contained in:
@@ -9,13 +9,18 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
Aug 27, 2024 (update 2):
|
||||
In FLUX.1 LoRA training, when `--fp8_base` is specified, a FLUX.1 model file that is already stored in fp8 (`float8_e4m3fn`) can be loaded directly. In `flux_minimal_inference.py`, such a file can also be loaded by specifying `fp8` (i.e. `float8_e4m3fn`) for `--flux_dtype`.
|
||||
|
||||
In `flux_merge_lora.py`, `fp8` can now be specified as the precision used when saving (see `--help` for details). Also, if no models to merge are specified, only the dtype conversion of the model is performed.
|
||||
|
||||
Aug 27, 2024:
|
||||
|
||||
- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI.
|
||||
- `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
|
||||
- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution.
|
||||
|
||||
- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option).
|
||||
- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is automatically enabled.
|
||||
|
||||
Aug 25, 2024:
|
||||
Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Timesteps are shifted according to the value of `--discrete_flow_shift` (the shift is applied to the sigmoid of a normally distributed random number). A value of 3.1582 (=e^1.15) for `--discrete_flow_shift` may be a good starting point.
|
||||
|
||||
@@ -10,7 +10,6 @@ import einops
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from safetensors.torch import safe_open, load_file
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import accelerate
|
||||
@@ -21,7 +20,7 @@ from library.device_utils import init_ipex, get_preferred_device
|
||||
init_ipex()
|
||||
|
||||
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging, str_to_dtype
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -288,28 +287,6 @@ if __name__ == "__main__":
|
||||
name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way
|
||||
is_schnell = name == "schnell"
|
||||
|
||||
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
||||
if s is None:
|
||||
return default_dtype
|
||||
if s in ["bf16", "bfloat16"]:
|
||||
return torch.bfloat16
|
||||
elif s in ["fp16", "float16"]:
|
||||
return torch.float16
|
||||
elif s in ["fp32", "float32"]:
|
||||
return torch.float32
|
||||
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
|
||||
return torch.float8_e4m3fn
|
||||
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
|
||||
return torch.float8_e4m3fnuz
|
||||
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
|
||||
return torch.float8_e5m2
|
||||
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
|
||||
return torch.float8_e5m2fnuz
|
||||
elif s in ["fp8", "float8"]:
|
||||
return torch.float8_e4m3fn # default fp8
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {s}")
|
||||
|
||||
def is_fp8(dt):
|
||||
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
|
||||
|
||||
@@ -348,7 +325,7 @@ if __name__ == "__main__":
|
||||
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
|
||||
|
||||
# DiT
|
||||
model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device)
|
||||
model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device)
|
||||
model.eval()
|
||||
logger.info(f"Casting model to {flux_dtype}")
|
||||
model.to(flux_dtype) # make sure model is dtype
|
||||
|
||||
@@ -29,6 +29,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.fp8_base_unet:
|
||||
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
||||
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning(
|
||||
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
||||
@@ -61,9 +64,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
name = self.get_flux_model_name(args)
|
||||
|
||||
# if we load to cpu, flux.to(fp8) takes a long time
|
||||
if args.fp8_base:
|
||||
loading_dtype = None # as is
|
||||
else:
|
||||
loading_dtype = weight_dtype
|
||||
|
||||
model = flux_utils.load_flow_model(
|
||||
name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
||||
name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
||||
)
|
||||
if args.fp8_base:
|
||||
# check dtype of model
|
||||
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
|
||||
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
||||
elif model.dtype == torch.float8_e4m3fn:
|
||||
logger.info("Loaded fp8 FLUX model")
|
||||
|
||||
if args.split_mode:
|
||||
model = self.prepare_split_model(model, weight_dtype, accelerator)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
import einops
|
||||
import torch
|
||||
|
||||
@@ -20,7 +20,9 @@ MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
|
||||
|
||||
# temporary copy from sd3_utils TODO refactor
|
||||
def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
||||
):
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
# use experimental loader
|
||||
@@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap:
|
||||
|
||||
|
||||
def load_flow_model(
|
||||
name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
||||
name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
||||
) -> flux_models.Flux:
|
||||
logger.info(f"Building Flux model {name}")
|
||||
with torch.device("meta"):
|
||||
model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
|
||||
model = flux_models.Flux(flux_models.configs[name].params)
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
# load_sft doesn't support torch.device
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
|
||||
@@ -82,6 +82,66 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
logger.info(msg_init)
|
||||
|
||||
|
||||
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
||||
"""
|
||||
Convert a string to a torch.dtype
|
||||
|
||||
Args:
|
||||
s: string representation of the dtype
|
||||
default_dtype: default dtype to return if s is None
|
||||
|
||||
Returns:
|
||||
torch.dtype: the corresponding torch.dtype
|
||||
|
||||
Raises:
|
||||
ValueError: if the dtype is not supported
|
||||
|
||||
Examples:
|
||||
>>> str_to_dtype("float32")
|
||||
torch.float32
|
||||
>>> str_to_dtype("fp32")
|
||||
torch.float32
|
||||
>>> str_to_dtype("float16")
|
||||
torch.float16
|
||||
>>> str_to_dtype("fp16")
|
||||
torch.float16
|
||||
>>> str_to_dtype("bfloat16")
|
||||
torch.bfloat16
|
||||
>>> str_to_dtype("bf16")
|
||||
torch.bfloat16
|
||||
>>> str_to_dtype("fp8")
|
||||
torch.float8_e4m3fn
|
||||
>>> str_to_dtype("fp8_e4m3fn")
|
||||
torch.float8_e4m3fn
|
||||
>>> str_to_dtype("fp8_e4m3fnuz")
|
||||
torch.float8_e4m3fnuz
|
||||
>>> str_to_dtype("fp8_e5m2")
|
||||
torch.float8_e5m2
|
||||
>>> str_to_dtype("fp8_e5m2fnuz")
|
||||
torch.float8_e5m2fnuz
|
||||
"""
|
||||
if s is None:
|
||||
return default_dtype
|
||||
if s in ["bf16", "bfloat16"]:
|
||||
return torch.bfloat16
|
||||
elif s in ["fp16", "float16"]:
|
||||
return torch.float16
|
||||
elif s in ["fp32", "float32", "float"]:
|
||||
return torch.float32
|
||||
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
|
||||
return torch.float8_e4m3fn
|
||||
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
|
||||
return torch.float8_e4m3fnuz
|
||||
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
|
||||
return torch.float8_e5m2
|
||||
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
|
||||
return torch.float8_e5m2fnuz
|
||||
elif s in ["fp8", "float8"]:
|
||||
return torch.float8_e4m3fn # default fp8
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {s}")
|
||||
|
||||
|
||||
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
|
||||
"""
|
||||
memory efficient save file
|
||||
@@ -198,7 +258,7 @@ class MemoryEfficientSafeOpen:
|
||||
if tensor_bytes is None:
|
||||
byte_tensor = torch.empty(0, dtype=torch.uint8)
|
||||
else:
|
||||
tensor_bytes = bytearray(tensor_bytes) # make it writable
|
||||
tensor_bytes = bytearray(tensor_bytes) # make it writable
|
||||
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
|
||||
|
||||
# process float8 types
|
||||
|
||||
@@ -8,7 +8,7 @@ from safetensors import safe_open
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -34,18 +34,23 @@ def load_state_dict(file_name, dtype):
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, state_dict, dtype, metadata):
|
||||
def save_to_file(file_name, state_dict, dtype, metadata, mem_eff_save=False):
|
||||
if dtype is not None:
|
||||
logger.info(f"converting to {dtype}...")
|
||||
for key in list(state_dict.keys()):
|
||||
for key in tqdm(list(state_dict.keys())):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
logger.info(f"saving to: {file_name}")
|
||||
save_file(state_dict, file_name, metadata=metadata)
|
||||
if mem_eff_save:
|
||||
mem_eff_save_file(state_dict, file_name, metadata=metadata)
|
||||
else:
|
||||
save_file(state_dict, file_name, metadata=metadata)
|
||||
|
||||
|
||||
def merge_to_flux_model(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype):
|
||||
def merge_to_flux_model(
|
||||
loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
|
||||
):
|
||||
# create module map without loading state_dict
|
||||
logger.info(f"loading keys from FLUX.1 model: {flux_model}")
|
||||
lora_name_to_module_key = {}
|
||||
@@ -57,7 +62,14 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati
|
||||
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
|
||||
lora_name_to_module_key[lora_name] = key
|
||||
|
||||
flux_state_dict = load_file(flux_model, device=loading_device)
|
||||
if mem_eff_load_save:
|
||||
flux_state_dict = {}
|
||||
with MemoryEfficientSafeOpen(flux_model) as flux_file:
|
||||
for key in tqdm(flux_file.keys()):
|
||||
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
|
||||
else:
|
||||
flux_state_dict = load_file(flux_model, device=loading_device)
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU
|
||||
@@ -120,9 +132,17 @@ def merge_to_flux_model(loading_device, working_device, flux_model, models, rati
|
||||
return flux_state_dict
|
||||
|
||||
|
||||
def merge_to_flux_model_diffusers(loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype):
|
||||
def merge_to_flux_model_diffusers(
|
||||
loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
|
||||
):
|
||||
logger.info(f"loading keys from FLUX.1 model: {flux_model}")
|
||||
flux_state_dict = load_file(flux_model, device=loading_device)
|
||||
if mem_eff_load_save:
|
||||
flux_state_dict = {}
|
||||
with MemoryEfficientSafeOpen(flux_model) as flux_file:
|
||||
for key in tqdm(flux_file.keys()):
|
||||
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
|
||||
else:
|
||||
flux_state_dict = load_file(flux_model, device=loading_device)
|
||||
|
||||
def create_key_map(n_double_layers, n_single_layers):
|
||||
key_map = {}
|
||||
@@ -474,19 +494,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
|
||||
|
||||
def merge(args):
|
||||
if args.models is None:
|
||||
args.models = []
|
||||
if args.ratios is None:
|
||||
args.ratios = []
|
||||
|
||||
assert len(args.models) == len(
|
||||
args.ratios
|
||||
), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
    """Translate a precision flag ("float", "fp16", "bf16") to a torch dtype; any other value gives None."""
    known_precisions = (
        ("float", torch.float),
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
    )
    for flag, dtype in known_precisions:
        if p == flag:
            return dtype
    return None
|
||||
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
@@ -500,11 +516,25 @@ def merge(args):
|
||||
if args.flux_model is not None:
|
||||
if not args.diffusers:
|
||||
state_dict = merge_to_flux_model(
|
||||
args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype
|
||||
args.loading_device,
|
||||
args.working_device,
|
||||
args.flux_model,
|
||||
args.models,
|
||||
args.ratios,
|
||||
merge_dtype,
|
||||
save_dtype,
|
||||
args.mem_eff_load_save,
|
||||
)
|
||||
else:
|
||||
state_dict = merge_to_flux_model_diffusers(
|
||||
args.loading_device, args.working_device, args.flux_model, args.models, args.ratios, merge_dtype, save_dtype
|
||||
args.loading_device,
|
||||
args.working_device,
|
||||
args.flux_model,
|
||||
args.models,
|
||||
args.ratios,
|
||||
merge_dtype,
|
||||
save_dtype,
|
||||
args.mem_eff_load_save,
|
||||
)
|
||||
|
||||
if args.no_metadata:
|
||||
@@ -517,7 +547,7 @@ def merge(args):
|
||||
)
|
||||
|
||||
logger.info(f"saving FLUX model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, save_dtype, sai_metadata)
|
||||
save_to_file(args.save_to, state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
|
||||
|
||||
else:
|
||||
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
||||
@@ -546,14 +576,14 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
|
||||
help="precision in saving, same to merging if omitted. supported types: "
|
||||
"float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz"
|
||||
" / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="float",
|
||||
choices=["float", "fp16", "bf16"],
|
||||
help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -562,6 +592,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem_eff_load_save",
|
||||
action="store_true",
|
||||
help="use custom memory efficient load and save functions for FLUX.1 model"
|
||||
" / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loading_device",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user