diff --git a/train_network.py b/train_network.py index 6b8ed9bd..7d082e20 100644 --- a/train_network.py +++ b/train_network.py @@ -1478,6 +1478,8 @@ class NetworkTrainer: ) progress_bar.unpause() + clean_memory_on_device(accelerator.device) + # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() @@ -1693,6 +1695,8 @@ class NetworkTrainer: progress_bar.unpause() optimizer_train_fn() + clean_memory_on_device(accelerator.device) + # end of epoch # metadata["ss_epoch"] = str(num_train_epochs)