diff --git a/train_network.py b/train_network.py index e148a92a..dd1fb748 100644 --- a/train_network.py +++ b/train_network.py @@ -197,6 +197,7 @@ def train(args): # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, @@ -509,6 +510,7 @@ def train(args): loss_list = [] loss_total = 0.0 + del train_dataset_group for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}")