From 2e67d74df46df88a9e9b3cae59373b7053b7a40c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Jul 2023 22:19:14 +0900 Subject: [PATCH] add no_half_vae option --- train_textual_inversion.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 1f085643..cbfd48ce 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -173,6 +173,7 @@ class TextualInversionTrainer: # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype # モデルを読み込む model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator) @@ -351,7 +352,7 @@ class TextualInversionTrainer: # 学習を準備する if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) vae.eval() with torch.no_grad(): @@ -447,10 +448,10 @@ class TextualInversionTrainer: else: unet.eval() - if not cache_latents: + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する vae.requires_grad_(False) vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=vae_dtype) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: @@ -529,7 +530,7 @@ class TextualInversionTrainer: latents = batch["latents"].to(accelerator.device) else: # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() latents = latents * self.vae_scale_factor # Get the text embedding for conditioning @@ -744,6 +745,11 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) return parser