diff --git a/README.md b/README.md index 34aa2bb2..3eed636c 100644 --- a/README.md +++ b/README.md @@ -4,21 +4,28 @@ This repository contains training, generation and utility scripts for Stable Dif SD3 training is done with `sd3_train.py`. +__Jun 29, 2024__: Fixed mixed precision training with fp16 is not working. Fixed the model is in bf16 dtype even without `--full_bf16` option (this could worsen the training result). + +`fp16` and `bf16` are available for mixed precision training. We are not sure which is better. + `optimizer_type = "adafactor"` is recommended for 24GB VRAM GPUs. `cache_text_encoder_outputs_to_disk` and `cache_latents_to_disk` are necessary currently. `clip_l`, `clip_g` and `t5xxl` can be specified if the checkpoint does not include them. -t5xxl doesn't seem to work with `fp16`, so use`bf16` or `fp32`. +t5xxl doesn't seem to work with `fp16`, so 1) use`bf16` for mixed precision, or 2) use `bf16` or `float32` for `t5xxl_dtype`. There are `t5xxl_device` and `t5xxl_dtype` options for `t5xxl` device and dtype. +`text_encoder_batch_size` is added experimentally for caching faster. + ```toml -learning_rate = 1e-5 # seems to be too high +learning_rate = 1e-6 # seems to depend on the batch size optimizer_type = "adafactor" optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ] cache_text_encoder_outputs = true cache_text_encoder_outputs_to_disk = true vae_batch_size = 1 +text_encoder_batch_size = 4 cache_latents = true cache_latents_to_disk = true ``` diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 70c83c0b..c8d52e1c 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -28,14 +28,14 @@ logger = logging.getLogger(__name__) from .sdxl_train_util import match_mixed_precision -def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[ +def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[ sd3_models.MMDiT, Optional[sd3_models.SDClipModel], Optional[sd3_models.SDXLClipG], Optional[sd3_models.T5XXLModel], sd3_models.SDVAE, ]: - model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16 for pi in range(accelerator.state.num_processes): if pi == accelerator.state.local_process_index: @@ -49,13 +49,15 @@ def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, args.vae, attn_mode, accelerator.device if args.lowram else "cpu", - weight_dtype, + model_dtype, args.disable_mmap_load_safetensors, + clip_dtype, t5xxl_device, t5xxl_dtype, + vae_dtype, ) - # work on low-ram device + # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device if args.lowram: if clip_l is not None: clip_l.to(accelerator.device) diff --git a/library/sd3_utils.py b/library/sd3_utils.py index c2c91412..45b49b04 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -28,11 +28,41 @@ def load_models( vae_path: str, attn_mode: str, device: Union[str, torch.device], - weight_dtype: torch.dtype, + default_dtype: Optional[Union[str, torch.dtype]] = None, disable_mmap: bool = False, - t5xxl_device: Optional[str] = None, - t5xxl_dtype: Optional[str] = None, + clip_dtype: Optional[Union[str, torch.dtype]] = None, + t5xxl_device: Optional[Union[str, torch.device]] = None, + t5xxl_dtype: Optional[Union[str, torch.dtype]] = None, + vae_dtype: Optional[Union[str, torch.dtype]] = None, ): + """ + Load SD3 models from checkpoint files. + + Args: + ckpt_path: Path to the SD3 checkpoint file. + clip_l_path: Path to the clip_l checkpoint file. + clip_g_path: Path to the clip_g checkpoint file. + t5xxl_path: Path to the t5xxl checkpoint file. + vae_path: Path to the VAE checkpoint file. + attn_mode: Attention mode for MMDiT model. + device: Device for MMDiT model. + default_dtype: Default dtype for each model. In training, it's usually None. None means using float32. + disable_mmap: Disable memory mapping when loading state dict. + clip_dtype: Dtype for Clip models, or None to use default dtype. + t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. + t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype. + vae_dtype: Dtype for VAE model, or None to use default dtype. + + Returns: + Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models. + """ + + # In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict. + # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. + # Therefore, we need clip_dtype and t5xxl_dtype. + + # default_dtype is used for full_fp16/full_bf16 training. + def load_state_dict(path: str, dvc: Union[str, torch.device] = device): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) @@ -43,6 +73,9 @@ def load_models( return load_file(path) # prevent device invalid Error t5xxl_device = t5xxl_device or device + clip_dtype = clip_dtype or default_dtype or torch.float32 + t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32 + vae_dtype = vae_dtype or default_dtype or torch.float32 logger.info(f"Loading SD3 models from {ckpt_path}...") state_dict = load_state_dict(ckpt_path) @@ -124,7 +157,7 @@ def load_models( mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype) logger.info(f"Loaded MMDiT: {info}") # load ClipG and ClipL @@ -132,7 +165,7 @@ def load_models( clip_l = None else: logger.info("Building ClipL") - clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd) + clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd) logger.info("Loading state dict...") info = clip_l.load_state_dict(clip_l_sd) logger.info(f"Loaded ClipL: {info}") @@ -142,7 +175,7 @@ def load_models( clip_g = None else: logger.info("Building ClipG") - clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd) + clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd) logger.info("Loading state dict...") info = clip_g.load_state_dict(clip_g_sd) logger.info(f"Loaded ClipG: {info}") @@ -165,6 +198,7 @@ def load_models( logger.info("Loading state dict...") info = vae.load_state_dict(vae_sd) logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) return mmdit, clip_l, clip_g, t5xxl, vae diff --git a/sd3_train.py b/sd3_train.py index b6c932c4..bd30cdc7 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -182,6 +182,8 @@ def train(args): raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device + clip_dtype = weight_dtype # if not args.train_text_encoder else None + # モデルを読み込む attn_mode = "xformers" if args.xformers else "torch" @@ -189,8 +191,9 @@ def train(args): attn_mode == "torch" ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。" + # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype + args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype ) assert clip_l is not None, "clip_l is required / clip_lは必須です" assert clip_g is not None, "clip_g is required / clip_gは必須です" @@ -868,8 +871,9 @@ def setup_parser() -> argparse.ArgumentParser: custom_train_functions.add_custom_train_arguments(parser) sd3_train_utils.add_sd3_training_arguments(parser) - # TE training is disabled temporarily + # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + # TE training is disabled temporarily # parser.add_argument( # "--learning_rate_te1", # type=float, @@ -886,7 +890,6 @@ def setup_parser() -> argparse.ArgumentParser: # parser.add_argument( # "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" # ) - # parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") # parser.add_argument( # "--no_half_vae", # action="store_true",