typo修正、stepをglobal_stepに修正、バグ修正

This commit is contained in:
u-haru
2023-03-23 09:53:14 +09:00
parent a9b26b73e0
commit 447c56bf50
6 changed files with 7 additions and 7 deletions

View File

@@ -241,6 +241,7 @@ def train(args):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
train_dataset_group.set_current_step(global_step)
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
@@ -249,7 +250,6 @@ def train(args):
text_encoder.requires_grad_(False)
with accelerator.accumulate(unet):
train_dataset_group.set_current_step(step + 1)
with torch.no_grad():
# latentに変換
if cache_latents: