From 5647f12bc386e5abe73ba9b474f4e2e4ec1d5083 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 29 Mar 2025 01:58:52 +0800 Subject: [PATCH] Update train_network.py --- train_network.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/train_network.py b/train_network.py index ac0747c0..f2a6241b 100644 --- a/train_network.py +++ b/train_network.py @@ -913,12 +913,12 @@ class NetworkTrainer: if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: - if args.train_vae_batch is None or len(batch["images"]) <= args.train_vae_batch: + if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: with torch.no_grad(): # latentに変換 latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype) else: - chunks = [batch["images"][i:i + args.train_vae_batch] for i in range(0, len(batch["images"]), args.train_vae_batch)] + chunks = [batch["images"][i:i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)] list_latents = [] for chunk in chunks: with torch.no_grad(): @@ -1240,12 +1240,7 @@ def setup_parser() -> argparse.ArgumentParser: help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", ) - parser.add_argument( - "--train_vae_batch", - type=int, - default=None, - help="Sets batch size of VAE when not using cached latents", - ) + # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")