diff --git a/train_network.py b/train_network.py
index 02a2d925..ef10921f 100644
--- a/train_network.py
+++ b/train_network.py
@@ -8,6 +8,7 @@ import random
 import time
 import json
 import toml
+from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
@@ -24,9 +25,23 @@ from library.config_util import (
     BlueprintGenerator,
 )
 
-
-def collate_fn(examples):
-    return examples[0]
+class collater_class:
+    def __init__(self, epoch, step, dataset):
+        self.current_epoch = epoch
+        self.current_step = step
+        self.dataset = dataset  # used only when there is no worker process (num_workers == 0)
+
+    def __call__(self, examples):
+        worker_info = torch.utils.data.get_worker_info()
+        # in a worker process, update that worker's copy of the dataset;
+        # get_worker_info() returns None when collation runs in the main process
+        if worker_info is not None:
+            dataset = worker_info.dataset
+        else:
+            dataset = self.dataset
+        dataset.set_current_epoch(self.current_epoch.value)
+        dataset.set_current_step(self.current_step.value)
+        return examples[0]
 
 
 # TODO: unify with the other scripts
@@ -101,6 +116,10 @@ def train(args):
     config_util.blueprint_args_conflict(args, blueprint)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i', 0)
+    current_step = Value('i', 0)
+    collater = collater_class(current_epoch, current_step, train_dataset_group)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
         return
@@ -186,11 +205,12 @@ def train(args):
     # prepare the dataloader
     # number of DataLoader worker processes; 0 means data is loaded in the main process
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
+
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -498,17 +518,18 @@ def train(args):
     loss_list = []
     loss_total = 0.0
 
+    del train_dataset_group  # from here on the dataset is reached only through the dataloader/collator
     for epoch in range(num_train_epochs):
         if is_main_process:
             print(f"epoch {epoch+1}/{num_train_epochs}")
-        train_dataset_group.set_current_epoch(epoch + 1)
-        train_dataset_group.set_current_step(global_step)
+        current_epoch.value = epoch + 1
 
         metadata["ss_epoch"] = str(epoch + 1)
 
         network.on_epoch_start(text_encoder, unet)
 
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
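
For context on the mechanism: when num_workers > 0 (and especially with persistent_workers, which reuses the same worker processes across epochs), each DataLoader worker holds its own pickled copy of the dataset, so calling set_current_epoch on the main-process train_dataset_group never reaches the copies that actually serve batches. The patch instead shares the epoch/step counters through multiprocessing.Value objects, which the workers inherit at process creation, and applies them inside the collate callable, which runs in the worker process. Below is a minimal, self-contained sketch of that pattern; ToyDataset and the printed fields are illustrative stand-ins, not part of the patch.

import torch
from multiprocessing import Value
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Stand-in for train_dataset_group; records the epoch/step it was told about."""

    def __init__(self):
        self.current_epoch = 0
        self.current_step = 0

    def set_current_epoch(self, epoch):
        self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # a real dataset would e.g. reshuffle its buckets based on current_epoch
        return {"idx": idx, "epoch_seen_by_worker": self.current_epoch}


class collater_class:
    def __init__(self, epoch, step, dataset):
        self.current_epoch = epoch
        self.current_step = step
        self.dataset = dataset  # fallback when num_workers == 0

    def __call__(self, examples):
        worker_info = torch.utils.data.get_worker_info()
        # collation runs inside the worker, so this reaches the worker's dataset copy
        dataset = worker_info.dataset if worker_info is not None else self.dataset
        dataset.set_current_epoch(self.current_epoch.value)
        dataset.set_current_step(self.current_step.value)
        return examples[0]


if __name__ == "__main__":
    dataset = ToyDataset()
    current_epoch = Value('i', 0)  # shared integers, inherited by the workers
    current_step = Value('i', 0)
    loader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=2,
        persistent_workers=True,
        collate_fn=collater_class(current_epoch, current_step, dataset),
    )
    for epoch in range(2):
        current_epoch.value = epoch + 1
        for batch in loader:
            print(batch)  # epoch_seen_by_worker tracks current_epoch.value
    # the main-process copy was never updated when num_workers > 0:
    # the workers' copies, not this object, served the batches
    print("main-process copy:", dataset.current_epoch)

Running the sketch shows each batch reporting the epoch set in the main process even though the persistent workers were created once and reused, which is exactly the behavior the patch needs for epoch-dependent shuffling.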