mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix fp16 mixed precision, model is in bf16 without full_bf16
This commit is contained in:
11
README.md
11
README.md
@@ -4,21 +4,28 @@ This repository contains training, generation and utility scripts for Stable Dif
|
|||||||
|
|
||||||
SD3 training is done with `sd3_train.py`.
|
SD3 training is done with `sd3_train.py`.
|
||||||
|
|
||||||
|
__Jun 29, 2024__: Fixed an issue where mixed precision training with fp16 was not working. Fixed an issue where the model was loaded in bf16 dtype even without the `--full_bf16` option (this could worsen the training result).
|
||||||
|
|
||||||
|
`fp16` and `bf16` are available for mixed precision training. We are not sure which is better.
|
||||||
|
|
||||||
`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently.
|
`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently.
|
||||||
|
|
||||||
`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.
|
`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.
|
||||||
|
|
||||||
t5xxl doesn't seem to work with `fp16`, so use`bf16` or `fp32`.
|
t5xxl doesn't seem to work with `fp16`, so 1) use `bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`.
|
||||||
|
|
||||||
There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.
|
There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.
|
||||||
|
|
||||||
|
`text_encoder_batch_size` is added experimentally for faster caching.
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
learning_rate = 1e-5 # seems to be too high
|
learning_rate = 1e-6 # seems to depend on the batch size
|
||||||
optimizer_type = "adafactor"
|
optimizer_type = "adafactor"
|
||||||
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
|
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
|
||||||
cache_text_encoder_outputs = true
|
cache_text_encoder_outputs = true
|
||||||
cache_text_encoder_outputs_to_disk = true
|
cache_text_encoder_outputs_to_disk = true
|
||||||
vae_batch_size = 1
|
vae_batch_size = 1
|
||||||
|
text_encoder_batch_size = 4
|
||||||
cache_latents = true
|
cache_latents = true
|
||||||
cache_latents_to_disk = true
|
cache_latents_to_disk = true
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -28,14 +28,14 @@ logger = logging.getLogger(__name__)
|
|||||||
from .sdxl_train_util import match_mixed_precision
|
from .sdxl_train_util import match_mixed_precision
|
||||||
|
|
||||||
|
|
||||||
def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[
|
def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[
|
||||||
sd3_models.MMDiT,
|
sd3_models.MMDiT,
|
||||||
Optional[sd3_models.SDClipModel],
|
Optional[sd3_models.SDClipModel],
|
||||||
Optional[sd3_models.SDXLClipG],
|
Optional[sd3_models.SDXLClipG],
|
||||||
Optional[sd3_models.T5XXLModel],
|
Optional[sd3_models.T5XXLModel],
|
||||||
sd3_models.SDVAE,
|
sd3_models.SDVAE,
|
||||||
]:
|
]:
|
||||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16
|
||||||
|
|
||||||
for pi in range(accelerator.state.num_processes):
|
for pi in range(accelerator.state.num_processes):
|
||||||
if pi == accelerator.state.local_process_index:
|
if pi == accelerator.state.local_process_index:
|
||||||
@@ -49,13 +49,15 @@ def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device,
|
|||||||
args.vae,
|
args.vae,
|
||||||
attn_mode,
|
attn_mode,
|
||||||
accelerator.device if args.lowram else "cpu",
|
accelerator.device if args.lowram else "cpu",
|
||||||
weight_dtype,
|
model_dtype,
|
||||||
args.disable_mmap_load_safetensors,
|
args.disable_mmap_load_safetensors,
|
||||||
|
clip_dtype,
|
||||||
t5xxl_device,
|
t5xxl_device,
|
||||||
t5xxl_dtype,
|
t5xxl_dtype,
|
||||||
|
vae_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# work on low-ram device
|
# work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
|
||||||
if args.lowram:
|
if args.lowram:
|
||||||
if clip_l is not None:
|
if clip_l is not None:
|
||||||
clip_l.to(accelerator.device)
|
clip_l.to(accelerator.device)
|
||||||
|
|||||||
@@ -28,11 +28,41 @@ def load_models(
|
|||||||
vae_path: str,
|
vae_path: str,
|
||||||
attn_mode: str,
|
attn_mode: str,
|
||||||
device: Union[str, torch.device],
|
device: Union[str, torch.device],
|
||||||
weight_dtype: torch.dtype,
|
default_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||||
disable_mmap: bool = False,
|
disable_mmap: bool = False,
|
||||||
t5xxl_device: Optional[str] = None,
|
clip_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||||
t5xxl_dtype: Optional[str] = None,
|
t5xxl_device: Optional[Union[str, torch.device]] = None,
|
||||||
|
t5xxl_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||||
|
vae_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Load SD3 models from checkpoint files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ckpt_path: Path to the SD3 checkpoint file.
|
||||||
|
clip_l_path: Path to the clip_l checkpoint file.
|
||||||
|
clip_g_path: Path to the clip_g checkpoint file.
|
||||||
|
t5xxl_path: Path to the t5xxl checkpoint file.
|
||||||
|
vae_path: Path to the VAE checkpoint file.
|
||||||
|
attn_mode: Attention mode for MMDiT model.
|
||||||
|
device: Device for MMDiT model.
|
||||||
|
default_dtype: Default dtype for each model. In training, it's usually None. None means using float32.
|
||||||
|
disable_mmap: Disable memory mapping when loading state dict.
|
||||||
|
clip_dtype: Dtype for Clip models, or None to use default dtype.
|
||||||
|
t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device.
|
||||||
|
t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype.
|
||||||
|
vae_dtype: Dtype for VAE model, or None to use default dtype.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict.
|
||||||
|
# However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict.
|
||||||
|
# Therefore, we need clip_dtype and t5xxl_dtype.
|
||||||
|
|
||||||
|
# default_dtype is used for full_fp16/full_bf16 training.
|
||||||
|
|
||||||
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
|
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
|
||||||
if disable_mmap:
|
if disable_mmap:
|
||||||
return safetensors.torch.load(open(path, "rb").read())
|
return safetensors.torch.load(open(path, "rb").read())
|
||||||
@@ -43,6 +73,9 @@ def load_models(
|
|||||||
return load_file(path) # prevent device invalid Error
|
return load_file(path) # prevent device invalid Error
|
||||||
|
|
||||||
t5xxl_device = t5xxl_device or device
|
t5xxl_device = t5xxl_device or device
|
||||||
|
clip_dtype = clip_dtype or default_dtype or torch.float32
|
||||||
|
t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32
|
||||||
|
vae_dtype = vae_dtype or default_dtype or torch.float32
|
||||||
|
|
||||||
logger.info(f"Loading SD3 models from {ckpt_path}...")
|
logger.info(f"Loading SD3 models from {ckpt_path}...")
|
||||||
state_dict = load_state_dict(ckpt_path)
|
state_dict = load_state_dict(ckpt_path)
|
||||||
@@ -124,7 +157,7 @@ def load_models(
|
|||||||
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
|
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
|
||||||
|
|
||||||
logger.info("Loading state dict...")
|
logger.info("Loading state dict...")
|
||||||
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype)
|
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype)
|
||||||
logger.info(f"Loaded MMDiT: {info}")
|
logger.info(f"Loaded MMDiT: {info}")
|
||||||
|
|
||||||
# load ClipG and ClipL
|
# load ClipG and ClipL
|
||||||
@@ -132,7 +165,7 @@ def load_models(
|
|||||||
clip_l = None
|
clip_l = None
|
||||||
else:
|
else:
|
||||||
logger.info("Building ClipL")
|
logger.info("Building ClipL")
|
||||||
clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd)
|
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
|
||||||
logger.info("Loading state dict...")
|
logger.info("Loading state dict...")
|
||||||
info = clip_l.load_state_dict(clip_l_sd)
|
info = clip_l.load_state_dict(clip_l_sd)
|
||||||
logger.info(f"Loaded ClipL: {info}")
|
logger.info(f"Loaded ClipL: {info}")
|
||||||
@@ -142,7 +175,7 @@ def load_models(
|
|||||||
clip_g = None
|
clip_g = None
|
||||||
else:
|
else:
|
||||||
logger.info("Building ClipG")
|
logger.info("Building ClipG")
|
||||||
clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd)
|
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
|
||||||
logger.info("Loading state dict...")
|
logger.info("Loading state dict...")
|
||||||
info = clip_g.load_state_dict(clip_g_sd)
|
info = clip_g.load_state_dict(clip_g_sd)
|
||||||
logger.info(f"Loaded ClipG: {info}")
|
logger.info(f"Loaded ClipG: {info}")
|
||||||
@@ -165,6 +198,7 @@ def load_models(
|
|||||||
logger.info("Loading state dict...")
|
logger.info("Loading state dict...")
|
||||||
info = vae.load_state_dict(vae_sd)
|
info = vae.load_state_dict(vae_sd)
|
||||||
logger.info(f"Loaded VAE: {info}")
|
logger.info(f"Loaded VAE: {info}")
|
||||||
|
vae.to(device=device, dtype=vae_dtype)
|
||||||
|
|
||||||
return mmdit, clip_l, clip_g, t5xxl, vae
|
return mmdit, clip_l, clip_g, t5xxl, vae
|
||||||
|
|
||||||
|
|||||||
@@ -182,6 +182,8 @@ def train(args):
|
|||||||
raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}")
|
raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}")
|
||||||
t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device
|
t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device
|
||||||
|
|
||||||
|
clip_dtype = weight_dtype # if not args.train_text_encoder else None
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
attn_mode = "xformers" if args.xformers else "torch"
|
attn_mode = "xformers" if args.xformers else "torch"
|
||||||
|
|
||||||
@@ -189,8 +191,9 @@ def train(args):
|
|||||||
attn_mode == "torch"
|
attn_mode == "torch"
|
||||||
), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
|
), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
|
||||||
|
|
||||||
|
# models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0.
|
||||||
mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
|
mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
|
||||||
args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype
|
args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
|
||||||
)
|
)
|
||||||
assert clip_l is not None, "clip_l is required / clip_lは必須です"
|
assert clip_l is not None, "clip_l is required / clip_lは必須です"
|
||||||
assert clip_g is not None, "clip_g is required / clip_gは必須です"
|
assert clip_g is not None, "clip_g is required / clip_gは必須です"
|
||||||
@@ -868,8 +871,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
custom_train_functions.add_custom_train_arguments(parser)
|
custom_train_functions.add_custom_train_arguments(parser)
|
||||||
sd3_train_utils.add_sd3_training_arguments(parser)
|
sd3_train_utils.add_sd3_training_arguments(parser)
|
||||||
|
|
||||||
# TE training is disabled temporarily
|
# parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||||
|
|
||||||
|
# TE training is disabled temporarily
|
||||||
# parser.add_argument(
|
# parser.add_argument(
|
||||||
# "--learning_rate_te1",
|
# "--learning_rate_te1",
|
||||||
# type=float,
|
# type=float,
|
||||||
@@ -886,7 +890,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
# parser.add_argument(
|
# parser.add_argument(
|
||||||
# "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
|
# "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
|
||||||
# )
|
# )
|
||||||
# parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
|
||||||
# parser.add_argument(
|
# parser.add_argument(
|
||||||
# "--no_half_vae",
|
# "--no_half_vae",
|
||||||
# action="store_true",
|
# action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user