Add epoch number to metadata

This commit is contained in:
space-nuko
2023-01-10 02:50:04 -08:00
parent 2e4ce0fdff
commit 0c4423d9dc

View File

@@ -194,7 +194,7 @@ def train(args):
 print(f" num epochs / epoch数: {num_train_epochs}")
 print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
 print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
 print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
 metadata = {
@@ -249,6 +249,7 @@ def train(args):
 for epoch in range(num_train_epochs):
 print(f"epoch {epoch+1}/{num_train_epochs}")
+metadata["ss_epoch"] = str(epoch+1)
 network.on_epoch_start(text_encoder, unet)
@@ -352,6 +353,8 @@ def train(args):
 # end of epoch
+metadata["ss_epoch"] = str(num_train_epochs)
 is_main_process = accelerator.is_main_process
 if is_main_process:
 network = unwrap_model(network)