diff --git a/library/config_util.py b/library/config_util.py
index 9c8c90c2..6817d9a3 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -499,11 +499,12 @@ def load_user_config(file: str) -> dict:
 
 def blueprint_args_conflict(args,blueprint:Blueprint):
   # train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする
-  for b in blueprint.dataset_group.datasets:
-    for t in b.subsets:
-      if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0):
-        print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir))
-        args.persistent_data_loader_workers = False
+  # for b in blueprint.dataset_group.datasets:
+  #   for t in b.subsets:
+  #     if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0):
+  #       print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir))
+  #       # args.persistent_data_loader_workers = False
+  return
 
 # for config test
 if __name__ == "__main__":
diff --git a/library/train_util.py b/library/train_util.py
index d1df9c58..223e403b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -517,6 +517,7 @@ class BaseDataset(torch.utils.data.Dataset):
       else:
         caption = caption.replace(str_from, str_to)
 
+    print(self.current_step, self.max_train_steps, caption)
     return caption
 
   def get_input_ids(self, caption):
diff --git a/train_network.py b/train_network.py
index 02a2d925..e148a92a 100644
--- a/train_network.py
+++ b/train_network.py
@@ -8,6 +8,7 @@ import random
 import time
 import json
 import toml
+from multiprocessing import Value
 from tqdm import tqdm
 
 import torch
@@ -24,9 +25,15 @@ from library.config_util import (
   BlueprintGenerator,
 )
 
-
-def collate_fn(examples):
-  return examples[0]
+class collater_class:
+  def __init__(self,epoch,step):
+    self.current_epoch=epoch
+    self.current_step=step
+  def __call__(self, examples):
+    dataset = torch.utils.data.get_worker_info().dataset
+    dataset.set_current_epoch(self.current_epoch.value)
+    dataset.set_current_step(self.current_step.value)
+    return examples[0]
 
 
 # TODO 他のスクリプトと共通化する
@@ -101,6 +108,10 @@ def train(args):
     config_util.blueprint_args_conflict(args,blueprint)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
+  current_epoch = Value('i',0)
+  current_step = Value('i',0)
+  collater = collater_class(current_epoch,current_step)
+
   if args.debug_dataset:
     train_util.debug_dataset(train_dataset_group)
     return
@@ -190,7 +201,7 @@
     train_dataset_group,
     batch_size=1,
     shuffle=True,
-    collate_fn=collate_fn,
+    collate_fn=collater,
     num_workers=n_workers,
     persistent_workers=args.persistent_data_loader_workers,
   )
@@ -501,14 +512,14 @@ def train(args):
   for epoch in range(num_train_epochs):
     if is_main_process:
       print(f"epoch {epoch+1}/{num_train_epochs}")
-    train_dataset_group.set_current_epoch(epoch + 1)
-    train_dataset_group.set_current_step(global_step)
+    current_epoch.value = epoch+1
 
     metadata["ss_epoch"] = str(epoch + 1)
 
     network.on_epoch_start(text_encoder, unet)
 
     for step, batch in enumerate(train_dataloader):
+      current_step.value = global_step
       with accelerator.accumulate(network):
         with torch.no_grad():
           if "latents" in batch and batch["latents"] is not None: