Fix fp16 mixed precision, model is in bf16 without full_bf16

This commit is contained in:
Kohya S
2024-06-29 17:21:25 +09:00
parent 66cf435479
commit 19086465e8
4 changed files with 61 additions and 15 deletions

View File

@@ -4,21 +4,28 @@ This repository contains training, generation and utility scripts for Stable Dif
SD3 training is done with `sd3_train.py`. SD3 training is done with `sd3_train.py`.
__Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result).
`fp16` and `bf16` are available for mixed precision training. We are not sure which is better.
`optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. `optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently.
`clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them.
t5xxl doesn't seem to work with `fp16`, so use`bf16` or `fp32`. t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`.
There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype.
`text_encoder_batch_size` is added experimentally for caching faster.
```toml ```toml
learning_rate = 1e-5 # seems to be too high learning_rate = 1e-6 # seems to depend on the batch size
optimizer_type = "adafactor" optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
cache_text_encoder_outputs = true cache_text_encoder_outputs = true
cache_text_encoder_outputs_to_disk = true cache_text_encoder_outputs_to_disk = true
vae_batch_size = 1 vae_batch_size = 1
text_encoder_batch_size = 4
cache_latents = true cache_latents = true
cache_latents_to_disk = true cache_latents_to_disk = true
``` ```

View File

@@ -28,14 +28,14 @@ logger = logging.getLogger(__name__)
from .sdxl_train_util import match_mixed_precision from .sdxl_train_util import match_mixed_precision
def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[ def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[
sd3_models.MMDiT, sd3_models.MMDiT,
Optional[sd3_models.SDClipModel], Optional[sd3_models.SDClipModel],
Optional[sd3_models.SDXLClipG], Optional[sd3_models.SDXLClipG],
Optional[sd3_models.T5XXLModel], Optional[sd3_models.T5XXLModel],
sd3_models.SDVAE, sd3_models.SDVAE,
]: ]:
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16
for pi in range(accelerator.state.num_processes): for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index: if pi == accelerator.state.local_process_index:
@@ -49,13 +49,15 @@ def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device,
args.vae, args.vae,
attn_mode, attn_mode,
accelerator.device if args.lowram else "cpu", accelerator.device if args.lowram else "cpu",
weight_dtype, model_dtype,
args.disable_mmap_load_safetensors, args.disable_mmap_load_safetensors,
clip_dtype,
t5xxl_device, t5xxl_device,
t5xxl_dtype, t5xxl_dtype,
vae_dtype,
) )
# work on low-ram device # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
if args.lowram: if args.lowram:
if clip_l is not None: if clip_l is not None:
clip_l.to(accelerator.device) clip_l.to(accelerator.device)

View File

@@ -28,11 +28,41 @@ def load_models(
vae_path: str, vae_path: str,
attn_mode: str, attn_mode: str,
device: Union[str, torch.device], device: Union[str, torch.device],
weight_dtype: torch.dtype, default_dtype: Optional[Union[str, torch.dtype]] = None,
disable_mmap: bool = False, disable_mmap: bool = False,
t5xxl_device: Optional[str] = None, clip_dtype: Optional[Union[str, torch.dtype]] = None,
t5xxl_dtype: Optional[str] = None, t5xxl_device: Optional[Union[str, torch.device]] = None,
t5xxl_dtype: Optional[Union[str, torch.dtype]] = None,
vae_dtype: Optional[Union[str, torch.dtype]] = None,
): ):
"""
Load SD3 models from checkpoint files.
Args:
ckpt_path: Path to the SD3 checkpoint file.
clip_l_path: Path to the clip_l checkpoint file.
clip_g_path: Path to the clip_g checkpoint file.
t5xxl_path: Path to the t5xxl checkpoint file.
vae_path: Path to the VAE checkpoint file.
attn_mode: Attention mode for MMDiT model.
device: Device for MMDiT model.
default_dtype: Default dtype for each model. In training, it's usually None. None means using float32.
disable_mmap: Disable memory mapping when loading state dict.
clip_dtype: Dtype for Clip models, or None to use default dtype.
t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device.
t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype.
vae_dtype: Dtype for VAE model, or None to use default dtype.
Returns:
Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models.
"""
# In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict.
# However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict.
# Therefore, we need clip_dtype and t5xxl_dtype.
# default_dtype is used for full_fp16/full_bf16 training.
def load_state_dict(path: str, dvc: Union[str, torch.device] = device): def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
if disable_mmap: if disable_mmap:
return safetensors.torch.load(open(path, "rb").read()) return safetensors.torch.load(open(path, "rb").read())
@@ -43,6 +73,9 @@ def load_models(
return load_file(path) # prevent device invalid Error return load_file(path) # prevent device invalid Error
t5xxl_device = t5xxl_device or device t5xxl_device = t5xxl_device or device
clip_dtype = clip_dtype or default_dtype or torch.float32
t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32
vae_dtype = vae_dtype or default_dtype or torch.float32
logger.info(f"Loading SD3 models from {ckpt_path}...") logger.info(f"Loading SD3 models from {ckpt_path}...")
state_dict = load_state_dict(ckpt_path) state_dict = load_state_dict(ckpt_path)
@@ -124,7 +157,7 @@ def load_models(
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
logger.info("Loading state dict...") logger.info("Loading state dict...")
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype)
logger.info(f"Loaded MMDiT: {info}") logger.info(f"Loaded MMDiT: {info}")
# load ClipG and ClipL # load ClipG and ClipL
@@ -132,7 +165,7 @@ def load_models(
clip_l = None clip_l = None
else: else:
logger.info("Building ClipL") logger.info("Building ClipL")
clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd) clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
logger.info("Loading state dict...") logger.info("Loading state dict...")
info = clip_l.load_state_dict(clip_l_sd) info = clip_l.load_state_dict(clip_l_sd)
logger.info(f"Loaded ClipL: {info}") logger.info(f"Loaded ClipL: {info}")
@@ -142,7 +175,7 @@ def load_models(
clip_g = None clip_g = None
else: else:
logger.info("Building ClipG") logger.info("Building ClipG")
clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd) clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
logger.info("Loading state dict...") logger.info("Loading state dict...")
info = clip_g.load_state_dict(clip_g_sd) info = clip_g.load_state_dict(clip_g_sd)
logger.info(f"Loaded ClipG: {info}") logger.info(f"Loaded ClipG: {info}")
@@ -165,6 +198,7 @@ def load_models(
logger.info("Loading state dict...") logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd) info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}") logger.info(f"Loaded VAE: {info}")
vae.to(device=device, dtype=vae_dtype)
return mmdit, clip_l, clip_g, t5xxl, vae return mmdit, clip_l, clip_g, t5xxl, vae

View File

@@ -182,6 +182,8 @@ def train(args):
raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}")
t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device
clip_dtype = weight_dtype # if not args.train_text_encoder else None
# モデルを読み込む # モデルを読み込む
attn_mode = "xformers" if args.xformers else "torch" attn_mode = "xformers" if args.xformers else "torch"
@@ -189,8 +191,9 @@ def train(args):
attn_mode == "torch" attn_mode == "torch"
), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
# models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0.
mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
) )
assert clip_l is not None, "clip_l is required / clip_lは必須です" assert clip_l is not None, "clip_l is required / clip_lは必須です"
assert clip_g is not None, "clip_g is required / clip_gは必須です" assert clip_g is not None, "clip_g is required / clip_gは必須です"
@@ -868,8 +871,9 @@ def setup_parser() -> argparse.ArgumentParser:
custom_train_functions.add_custom_train_arguments(parser) custom_train_functions.add_custom_train_arguments(parser)
sd3_train_utils.add_sd3_training_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser)
# TE training is disabled temporarily # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
# TE training is disabled temporarily
# parser.add_argument( # parser.add_argument(
# "--learning_rate_te1", # "--learning_rate_te1",
# type=float, # type=float,
@@ -886,7 +890,6 @@ def setup_parser() -> argparse.ArgumentParser:
# parser.add_argument( # parser.add_argument(
# "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
# ) # )
# parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
# parser.add_argument( # parser.add_argument(
# "--no_half_vae", # "--no_half_vae",
# action="store_true", # action="store_true",