Add save_n_epoch_ratio

This commit is contained in:
forestsource
2023-01-22 02:57:12 +09:00
parent cae42728ab
commit 5e817e4343
4 changed files with 8 additions and 0 deletions

View File

@@ -192,6 +192,8 @@ def train(args):
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps