mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
Update train_network.py
This commit is contained in:
@@ -913,12 +913,12 @@ class NetworkTrainer:
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
if args.train_vae_batch is None or len(batch["images"]) <= args.train_vae_batch:
|
||||
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
||||
else:
|
||||
chunks = [batch["images"][i:i + args.train_vae_batch] for i in range(0, len(batch["images"]), args.train_vae_batch)]
|
||||
chunks = [batch["images"][i:i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)]
|
||||
list_latents = []
|
||||
for chunk in chunks:
|
||||
with torch.no_grad():
|
||||
@@ -1240,12 +1240,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
||||
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_vae_batch",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Sets batch size of VAE when not using cached latents",
|
||||
)
|
||||
|
||||
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
||||
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
||||
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
||||
|
||||
Reference in New Issue
Block a user