Merge pull request #92 from forestsource/add_save_n_epoch_ratio

Add save_n_epoch_ratio
This commit is contained in:
Kohya S
2023-01-24 18:59:47 +09:00
committed by GitHub
4 changed files with 8 additions and 0 deletions

View File

@@ -212,6 +212,8 @@ def train(args):
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps