データセットにepoch、stepが通達されないバグ修正

This commit is contained in:
u-haru
2023-03-26 01:44:25 +09:00
parent 1b89b2a10e
commit 292cdb8379
3 changed files with 24 additions and 11 deletions

View File

@@ -499,11 +499,12 @@ def load_user_config(file: str) -> dict:
def blueprint_args_conflict(args, blueprint: Blueprint):
    """Check for conflicts between CLI args and the dataset blueprint.

    Currently a deliberate no-op: the check below used to disable
    --persistent_data_loader_workers because set_current_epoch()/
    set_current_step() were only applied when DataLoader workers were
    (re)created, so with persistent workers the values stayed frozen.
    That propagation is handled differently now, so the check is disabled
    but kept for reference.
    """
    # Disabled conflict check (kept for reference):
    # for b in blueprint.dataset_group.datasets:
    #     for t in b.subsets:
    #         if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step > 0):
    #             print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_warmup_step." % (t.params.image_dir))
    #             args.persistent_data_loader_workers = False
    return
# for config test # for config test
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -517,6 +517,7 @@ class BaseDataset(torch.utils.data.Dataset):
else: else:
caption = caption.replace(str_from, str_to) caption = caption.replace(str_from, str_to)
print(self.current_step, self.max_train_steps, caption)
return caption return caption
def get_input_ids(self, caption): def get_input_ids(self, caption):

View File

@@ -8,6 +8,7 @@ import random
import time import time
import json import json
import toml import toml
from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -24,8 +25,14 @@ from library.config_util import (
BlueprintGenerator, BlueprintGenerator,
) )
class collater_class:
    """Picklable collate callable that propagates the current epoch/step
    to the dataset living inside each DataLoader worker process.

    A plain ``collate_fn`` closure cannot see epoch/step updates made in the
    main process once workers are forked, so the values are shared through
    ``multiprocessing.Value`` objects and pushed into the worker-side dataset
    on every batch.
    """

    def __init__(self, epoch, step, dataset=None):
        # epoch / step: multiprocessing.Value('i', ...) shared counters,
        # updated by the training loop in the main process.
        self.current_epoch = epoch
        self.current_step = step
        # Fallback dataset for num_workers=0: get_worker_info() returns None
        # in the main process, so the original code crashed there. Optional
        # and defaulting to None to stay backward-compatible with callers
        # that pass only (epoch, step).
        self.dataset = dataset

    def __call__(self, examples):
        worker_info = torch.utils.data.get_worker_info()
        # In a worker process, reach the worker's copy of the dataset;
        # otherwise fall back to the dataset handed in at construction time.
        dataset = worker_info.dataset if worker_info is not None else self.dataset
        if dataset is not None:
            dataset.set_current_epoch(self.current_epoch.value)
            dataset.set_current_step(self.current_step.value)
        # batch_size=1 with an outer dataset that already batches: unwrap.
        return examples[0]
@@ -101,6 +108,10 @@ def train(args):
config_util.blueprint_args_conflict(args,blueprint) config_util.blueprint_args_conflict(args,blueprint)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
collater = collater_class(current_epoch,current_step)
if args.debug_dataset: if args.debug_dataset:
train_util.debug_dataset(train_dataset_group) train_util.debug_dataset(train_dataset_group)
return return
@@ -190,7 +201,7 @@ def train(args):
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collater,
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
@@ -501,14 +512,14 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
if is_main_process: if is_main_process:
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) current_epoch.value = epoch+1
train_dataset_group.set_current_step(global_step)
metadata["ss_epoch"] = str(epoch + 1) metadata["ss_epoch"] = str(epoch + 1)
network.on_epoch_start(text_encoder, unet) network.on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network): with accelerator.accumulate(network):
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None: