Add no_half_vae for SDXL training, add nan check

This commit is contained in:
Kohya S
2023-06-26 20:38:09 +09:00
parent 56ca5dfa15
commit 2c461e4ad3
3 changed files with 35 additions and 11 deletions

View File

@@ -906,6 +906,11 @@ class BaseDataset(torch.utils.data.Dataset):
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
# check NaN
for info, latents1 in zip(batch, latents):
if torch.isnan(latents1).any():
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
for info, latent in zip(batch, latents): for info, latent in zip(batch, latents):
if cache_to_disk: if cache_to_disk:
np.savez( np.savez(

View File

@@ -109,6 +109,7 @@ def train(args):
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む # モデルを読み込む
( (
@@ -165,8 +166,7 @@ def train(args):
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
# vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=vae_dtype)
vae.to(accelerator.device, dtype=torch.float32) # VAE in float to avoid NaN
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
@@ -201,7 +201,7 @@ def train(args):
if not cache_latents: if not cache_latents:
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models: for m in training_models:
m.requires_grad_(True) m.requires_grad_(True)
@@ -342,8 +342,12 @@ def train(args):
else: else:
with torch.no_grad(): with torch.no_grad():
# latentに変換 # latentに変換
# latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
latents = vae.encode(batch["images"].to(torch.float32)).latent_dist.sample().to(weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
b_size = latents.shape[0] b_size = latents.shape[0]
@@ -592,6 +596,11 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser return parser

View File

@@ -207,6 +207,7 @@ class NetworkTrainer:
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む # モデルを読み込む
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
@@ -216,6 +217,7 @@ class NetworkTrainer:
# モデルに xformers とか memory efficient attention を組み込む # モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
vae.set_use_memory_efficient_attention_xformers(args.xformers)
# 差分追加学習のためにモデルを読み込む # 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(__file__))
@@ -241,7 +243,7 @@ class NetworkTrainer:
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
@@ -415,7 +417,7 @@ class NetworkTrainer:
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=vae_dtype)
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
self.cache_text_encoder_outputs_if_needed( self.cache_text_encoder_outputs_if_needed(
@@ -721,7 +723,12 @@ class NetworkTrainer:
latents = batch["latents"].to(accelerator.device) latents = batch["latents"].to(accelerator.device)
else: else:
# latentに変換 # latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * self.vae_scale_factor latents = latents * self.vae_scale_factor
b_size = latents.shape[0] b_size = latents.shape[0]
@@ -863,9 +870,7 @@ class NetworkTrainer:
if args.save_state: if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images( self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
# end of epoch # end of epoch
@@ -961,6 +966,11 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*", nargs="*",
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
) )
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser return parser