diff --git a/train_network.py b/train_network.py index 2f8797d2..c7354873 100644 --- a/train_network.py +++ b/train_network.py @@ -1493,6 +1493,8 @@ class NetworkTrainer: ) progress_bar.unpause() + clean_memory_on_device(accelerator.device) + # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() @@ -1708,6 +1710,8 @@ class NetworkTrainer: progress_bar.unpause() optimizer_train_fn() + clean_memory_on_device(accelerator.device) + # end of epoch # metadata["ss_epoch"] = str(num_train_epochs)