Add no_half_vae for SDXL training, add nan check

This commit is contained in:
Kohya S
2023-06-26 20:38:09 +09:00
parent 56ca5dfa15
commit 2c461e4ad3
3 changed files with 35 additions and 11 deletions

View File

@@ -109,6 +109,7 @@ def train(args):
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
(
@@ -165,8 +166,7 @@ def train(args):
# 学習を準備する
if cache_latents:
# vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=torch.float32) # VAE in float to avoid NaN
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
@@ -201,7 +201,7 @@ def train(args):
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models:
m.requires_grad_(True)
@@ -342,8 +342,12 @@ def train(args):
else:
with torch.no_grad():
# latentに変換
# latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(torch.float32)).latent_dist.sample().to(weight_dtype)
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
b_size = latents.shape[0]
@@ -592,6 +596,11 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser