Merge branch 'kohya-ss:main' into min-SNR

This commit is contained in:
AI-Casanova
2023-03-21 13:19:15 -05:00
committed by GitHub
25 changed files with 253 additions and 82 deletions

View File

@@ -139,7 +139,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae)
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -196,7 +196,7 @@ def train(args):
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
@@ -654,7 +654,7 @@ def train(args):
print("model saved.")
if __name__ == "__main__":
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
@@ -697,6 +697,12 @@ if __name__ == "__main__":
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)