implement token warmup

This commit is contained in:
u-haru
2023-03-23 07:37:14 +09:00
parent 432353185c
commit a9b26b73e0
6 changed files with 75 additions and 2 deletions

View File

@@ -162,6 +162,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
@@ -246,6 +249,7 @@ def train(args):
text_encoder.requires_grad_(False)
with accelerator.accumulate(unet):
train_dataset_group.set_current_step(step + 1)
with torch.no_grad():
# latentに変換
if cache_latents: