データセットにepoch、stepが通達されないバグ修正

This commit is contained in:
u-haru
2023-03-26 01:44:25 +09:00
parent 1b89b2a10e
commit 292cdb8379
3 changed files with 24 additions and 11 deletions

View File

@@ -499,11 +499,12 @@ def load_user_config(file: str) -> dict:
def blueprint_args_conflict(args, blueprint: "Blueprint"):
    """Disable --persistent_data_loader_workers when it conflicts with dataset options.

    set_current_epoch() / set_current_step() on the dataset group only take
    effect when DataLoader workers are (re)created, so with persistent workers
    the epoch/step seen by the dataset would stay frozen.  If any subset uses
    caption_dropout_every_n_epochs or token_warmup_step, persistent workers
    must therefore be turned off.

    Mutates ``args`` in place (sets persistent_data_loader_workers = False on
    conflict) and returns None.
    """
    for dataset in blueprint.dataset_group.datasets:
        for subset in dataset.subsets:
            params = subset.params
            if args.persistent_data_loader_workers and (
                params.caption_dropout_every_n_epochs > 0 or params.token_warmup_step > 0
            ):
                # Fixed typo in the English half of the message: "token_wormup_step" -> "token_warmup_step".
                print(
                    "Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_warmup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"
                    % (params.image_dir)
                )
                args.persistent_data_loader_workers = False
    return
# for config test
if __name__ == "__main__":

View File

@@ -517,6 +517,7 @@ class BaseDataset(torch.utils.data.Dataset):
else:
caption = caption.replace(str_from, str_to)
print(self.current_step, self.max_train_steps, caption)
return caption
def get_input_ids(self, caption):