implement token warmup

Author: u-haru
Date: 2023-03-23 07:37:14 +09:00
Parent: 432353185c
Commit: a9b26b73e0
6 changed files with 75 additions and 2 deletions


@@ -260,6 +260,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * len(train_dataloader)
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

+    # send the training step count to the dataset side as well
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
     # prepare the lr scheduler
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
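
Sending args.max_train_steps to the dataset group lets a warmup length expressed as a fraction of the run be resolved into an absolute step count, and the second hunk below refreshes the matching current step every batch. A minimal sketch of the bookkeeping the two setters imply, assuming plain attributes on the group (the class body is illustrative, not the repository's actual code):

class TrainDatasetGroup:
    # Illustrative stand-in: only the step bookkeeping this commit adds is shown.
    def __init__(self):
        self.max_train_steps = 0  # total steps, sent once before training starts
        self.current_step = 0     # refreshed every batch by the training loop

    def set_max_train_steps(self, max_train_steps):
        self.max_train_steps = max_train_steps

    def set_current_step(self, current_step):
        self.current_step = current_step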
@@ -338,6 +341,7 @@ def train(args):
         loss_total = 0
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
+                train_dataset_group.set_current_step(step + 1)
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device)
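
With current_step refreshed every batch, the dataset side can grow the number of caption tokens it keeps as training progresses, which is what the commit title's "token warmup" refers to. A hedged sketch of one such schedule, a linear ramp; the parameter names token_warmup_min and token_warmup_step and the ramp shape are assumptions, not something this diff shows:

import math

def warmed_up_token_count(tokens, current_step, token_warmup_min, token_warmup_step):
    # Grow the usable token count linearly from token_warmup_min at step 0
    # to the full caption length once token_warmup_step is reached.
    if token_warmup_step <= 0 or current_step >= token_warmup_step:
        return len(tokens)
    grown = math.floor(current_step * (len(tokens) - token_warmup_min) / token_warmup_step)
    return min(len(tokens), token_warmup_min + grown)

# Example: a 75-token caption warming up over the first 1000 steps
# warmed_up_token_count(list(range(75)), 250, 4, 1000) -> 21

Under this sketch, early batches see only the first few caption tokens, and the full caption is restored once the warmup window ends; the diff's two setters are exactly the inputs such a schedule needs.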