From 0c4423d9dc65465cf1a65762aca6c8b642a52759 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 10 Jan 2023 02:50:04 -0800 Subject: [PATCH] Add epoch number to metadata --- train_network.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 3f45bae0..c5593c46 100644 --- a/train_network.py +++ b/train_network.py @@ -194,7 +194,7 @@ def train(args): print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") metadata = { @@ -249,6 +249,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + metadata["ss_epoch"] = str(epoch+1) network.on_epoch_start(text_encoder, unet) @@ -352,6 +353,8 @@ def train(args): # end of epoch + metadata["ss_epoch"] = str(num_train_epochs) + is_main_process = accelerator.is_main_process if is_main_process: network = unwrap_model(network)