mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add no_half_vae for SDXL training, add NaN check
This commit is contained in:
@@ -906,6 +906,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||||
|
|
||||||
|
# check NaN
|
||||||
|
for info, latents1 in zip(batch, latents):
|
||||||
|
if torch.isnan(latents1).any():
|
||||||
|
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||||
|
|
||||||
for info, latent in zip(batch, latents):
|
for info, latent in zip(batch, latents):
|
||||||
if cache_to_disk:
|
if cache_to_disk:
|
||||||
np.savez(
|
np.savez(
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ def train(args):
|
|||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
(
|
(
|
||||||
@@ -165,8 +166,7 @@ def train(args):
|
|||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
# vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.to(accelerator.device, dtype=torch.float32) # VAE in float to avoid NaN
|
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -201,7 +201,7 @@ def train(args):
|
|||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.requires_grad_(True)
|
m.requires_grad_(True)
|
||||||
@@ -342,8 +342,12 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# latentに変換
|
# latentに変換
|
||||||
# latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||||
latents = vae.encode(batch["images"].to(torch.float32)).latent_dist.sample().to(weight_dtype)
|
|
||||||
|
# NaNが含まれていれば警告を表示し0に置き換える
|
||||||
|
if torch.any(torch.isnan(latents)):
|
||||||
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
|
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
@@ -592,6 +596,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_half_vae",
|
||||||
|
action="store_true",
|
||||||
|
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
@@ -207,6 +207,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
||||||
@@ -216,6 +217,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||||
|
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
||||||
|
|
||||||
# 差分追加学習のためにモデルを読み込む
|
# 差分追加学習のためにモデルを読み込む
|
||||||
sys.path.append(os.path.dirname(__file__))
|
sys.path.append(os.path.dirname(__file__))
|
||||||
@@ -241,7 +243,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -415,7 +417,7 @@ class NetworkTrainer:
|
|||||||
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
|
|
||||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||||
self.cache_text_encoder_outputs_if_needed(
|
self.cache_text_encoder_outputs_if_needed(
|
||||||
@@ -721,7 +723,12 @@ class NetworkTrainer:
|
|||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
else:
|
else:
|
||||||
# latentに変換
|
# latentに変換
|
||||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
|
||||||
|
|
||||||
|
# NaNが含まれていれば警告を表示し0に置き換える
|
||||||
|
if torch.any(torch.isnan(latents)):
|
||||||
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
|
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||||
latents = latents * self.vae_scale_factor
|
latents = latents * self.vae_scale_factor
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
@@ -863,9 +870,7 @@ class NetworkTrainer:
|
|||||||
if args.save_state:
|
if args.save_state:
|
||||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||||
|
|
||||||
self.sample_images(
|
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||||
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
|
||||||
)
|
|
||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
@@ -961,6 +966,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
nargs="*",
|
nargs="*",
|
||||||
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_half_vae",
|
||||||
|
action="store_true",
|
||||||
|
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user