From 2c461e4ad39f00189114671d8a3fcdd748e6447d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 26 Jun 2023 20:38:09 +0900 Subject: [PATCH] Add no_half_vae for SDXL training, add nan check --- library/train_util.py | 5 +++++ sdxl_train.py | 19 ++++++++++++++----- train_network.py | 22 ++++++++++++++++------ 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e609705e..c27e1b27 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -906,6 +906,11 @@ class BaseDataset(torch.utils.data.Dataset): latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + # check NaN + for info, latents1 in zip(batch, latents): + if torch.isnan(latents1).any(): + raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") + for info, latent in zip(batch, latents): if cache_to_disk: np.savez( diff --git a/sdxl_train.py b/sdxl_train.py index 2683038b..56240744 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -109,6 +109,7 @@ def train(args): # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む ( @@ -165,8 +166,7 @@ def train(args): # 学習を準備する if cache_latents: - # vae.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=torch.float32) # VAE in float to avoid NaN + vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() with torch.no_grad(): @@ -201,7 +201,7 @@ def train(args): if not cache_latents: vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) for m in training_models: m.requires_grad_(True) @@ -342,8 +342,12 @@ def train(args): else: with torch.no_grad(): # latentに変換 - # latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = vae.encode(batch["images"].to(torch.float32)).latent_dist.sample().to(weight_dtype) + latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR b_size = latents.shape[0] @@ -592,6 +596,11 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) return parser diff --git a/train_network.py b/train_network.py index e2920db4..2bb94b1e 100644 --- a/train_network.py +++ b/train_network.py @@ -207,6 +207,7 @@ class NetworkTrainer: # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) @@ -216,6 +217,7 @@ class NetworkTrainer: # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + vae.set_use_memory_efficient_attention_xformers(args.xformers) # 差分追加学習のためにモデルを読み込む sys.path.append(os.path.dirname(__file__)) @@ -241,7 +243,7 @@ class NetworkTrainer: # 学習を準備する if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() with torch.no_grad(): @@ -415,7 +417,7 @@ class NetworkTrainer: if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される self.cache_text_encoder_outputs_if_needed( @@ -721,7 +723,12 @@ class NetworkTrainer: latents = batch["latents"].to(accelerator.device) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) latents = latents * self.vae_scale_factor b_size = latents.shape[0] @@ -863,9 +870,7 @@ class NetworkTrainer: if args.save_state: train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - self.sample_images( - accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet - ) + self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # end of epoch @@ -961,6 +966,11 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) return parser