mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add epoch number to metadata
This commit is contained in:
@@ -194,7 +194,7 @@ def train(args):
|
|||||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
@@ -249,6 +249,7 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
metadata["ss_epoch"] = str(epoch+1)
|
||||||
|
|
||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
|
|
||||||
@@ -352,6 +353,8 @@ def train(args):
|
|||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
|
metadata["ss_epoch"] = str(num_train_epochs)
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
network = unwrap_model(network)
|
network = unwrap_model(network)
|
||||||
|
|||||||
Reference in New Issue
Block a user