From 5ec90990de870a4579721db947c2f74b9ce3ed69 Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Sun, 26 Mar 2023 01:41:24 +0900
Subject: [PATCH] Fix a bug where epoch and step were not passed to the dataset
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 train_network.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/train_network.py b/train_network.py
index 02a2d925..ef10921f 100644
--- a/train_network.py
+++ b/train_network.py
@@ -8,6 +8,7 @@ import random
 import time
 import json
 import toml
+from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
@@ -24,9 +25,18 @@ from library.config_util import (
     BlueprintGenerator,
 )
 
-
-def collate_fn(examples):
-    return examples[0]
+class collater_class:
+    def __init__(self, epoch, step):
+        self.current_epoch = epoch
+        self.current_step = step
+    def __call__(self, examples):
+        dataset = torch.utils.data.get_worker_info().dataset
+        dataset.set_current_epoch(self.current_epoch.value)
+        dataset.set_current_step(self.current_step.value)
+        # print("self.current_step: %d" % self.current_step.value)
+        # print("dataset_length: %d" % len(dataset))
+        # print("id(self) (collate): %d" % id(self))
+        return examples[0]
 
 
 # TODO unify this with the other scripts
@@ -101,6 +111,10 @@ def train(args):
     config_util.blueprint_args_conflict(args,blueprint)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+    current_epoch = Value('i', 0)
+    current_step = Value('i', 0)
+    collater = collater_class(current_epoch, current_step)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)
         return
@@ -186,11 +200,12 @@ def train(args):
     # prepare the dataloader
     # number of DataLoader processes: 0 means loading runs in the main process
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but capped at the specified number
+
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset_group,
         batch_size=1,
         shuffle=True,
-        collate_fn=collate_fn,
+        collate_fn=collater,
         num_workers=n_workers,
         persistent_workers=args.persistent_data_loader_workers,
     )
@@ -498,17 +513,18 @@
     loss_list = []
     loss_total = 0.0
+    del train_dataset_group
     for epoch in range(num_train_epochs):
         if is_main_process:
             print(f"epoch {epoch+1}/{num_train_epochs}")
 
-        train_dataset_group.set_current_epoch(epoch + 1)
-        train_dataset_group.set_current_step(global_step)
+        current_epoch.value = epoch + 1
 
         metadata["ss_epoch"] = str(epoch + 1)
 
         network.on_epoch_start(text_encoder, unet)
 
         for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
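
Note for reviewers: the bug exists because each DataLoader worker process holds
its own copy of the dataset, so set_current_epoch()/set_current_step() called
on the main-process object never reach the copies that actually serve batches.
A multiprocessing.Value lives in shared memory, so reading it inside the
collate function, which runs in the worker process, always sees the latest
epoch/step; `del train_dataset_group` then guards against the now-unsynced
main-process copy being used by mistake. Below is a self-contained sketch of
the same pattern, not code from this repository: ToyDataset and Collater are
illustrative names, and the fallback for num_workers == 0 (where
torch.utils.data.get_worker_info() returns None, a case the collater in this
patch as written would not handle) is an assumption added for completeness.

    # Minimal sketch (assumed names: ToyDataset, Collater) of the shared-Value
    # pattern this patch relies on.
    from multiprocessing import Value

    import torch
    from torch.utils.data import DataLoader, Dataset


    class ToyDataset(Dataset):
        # Stand-in for train_dataset_group: it caches the epoch it is told about.
        def __init__(self):
            self.current_epoch = 0

        def set_current_epoch(self, epoch):
            self.current_epoch = epoch

        def __len__(self):
            return 4

        def __getitem__(self, index):
            return index


    class Collater:
        def __init__(self, epoch, dataset):
            self.current_epoch = epoch  # multiprocessing.Value, shared memory
            self.dataset = dataset      # fallback when loading runs in the main process

        def __call__(self, examples):
            worker_info = torch.utils.data.get_worker_info()
            # get_worker_info() returns None when num_workers == 0; this guard
            # is an addition, the patch's collater assumes a worker process.
            dataset = worker_info.dataset if worker_info is not None else self.dataset
            dataset.set_current_epoch(self.current_epoch.value)
            # Report the epoch the worker-local dataset copy now sees.
            return dataset.current_epoch, examples


    if __name__ == "__main__":
        dataset = ToyDataset()
        epoch = Value("i", 0)
        loader = DataLoader(dataset, batch_size=2, num_workers=2,
                            collate_fn=Collater(epoch, dataset))
        for e in range(2):
            epoch.value = e + 1  # visible inside workers because Value is shared
            for seen_epoch, items in loader:
                print(f"epoch set to {e + 1}, worker saw {seen_epoch}, batch {items}")

Running the sketch prints the worker-side epoch for every batch, confirming
that each worker observes the value set in the main process just before the
epoch starts, which is exactly what the patched collater does for
train_dataset_group.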