Update train_network.py

DKnight54
2025-03-28 17:48:47 +08:00
committed by GitHub
parent 258c8b5e80
commit 666468857b


@@ -913,14 +913,23 @@ class NetworkTrainer:
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
if args.train_vae_batch is None or len(batch["latents"] <= args.train_vae_batch
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
else:
chunks = [batch["images"][i:i + args.train_vae_batch] for i in range(0, len(batch["images"]), args.train_vae_batch)]
list_latents = []
for chunk in chunks:
with torch.no_grad():
# latentに変換
list_latents.append(vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype))
latents = torch.cat(list_latents, dim=0)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
logger.info(f"Latents Shape: {latents.shape}")
latents = latents * self.vae_scale_factor
# get multiplier for each sample
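Note: the chunked path added above bounds peak VRAM by never pushing more than train_vae_batch images through the VAE at once, then stitching the per-chunk latents back together with torch.cat. A minimal standalone sketch of the same pattern (encode_fn is a hypothetical stand-in for vae.encode(...).latent_dist.sample()):

    import torch

    def encode_in_chunks(images: torch.Tensor, encode_fn, chunk_size=None) -> torch.Tensor:
        # No chunk size set, or the batch already fits: encode everything in one call.
        if chunk_size is None or len(images) <= chunk_size:
            with torch.no_grad():
                return encode_fn(images)
        # Otherwise encode chunk_size images at a time to cap memory use.
        outputs = []
        for i in range(0, len(images), chunk_size):
            with torch.no_grad():
                outputs.append(encode_fn(images[i : i + chunk_size]))
        # Concatenating along dim=0 restores the original batch order,
        # so downstream loss computation sees the same layout as a single call.
        return torch.cat(outputs, dim=0)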
@@ -1231,6 +1240,12 @@ def setup_parser() -> argparse.ArgumentParser:
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ未指定時と同じ。initial_epochを上書きする",
)
parser.add_argument(
"--train_vae_batch",
type=int,
default=None,
help="Sets batch size of VAE when not using cached latents",
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")