From a4b34a9c3ce36c93ae12d5b4b98ea9f93e1606a3 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 03:26:55 +0900 Subject: [PATCH] =?UTF-8?q?blueprint=5Fargs=5Fconflict=E3=81=AF=E4=B8=8D?= =?UTF-8?q?=E8=A6=81=E3=81=AA=E3=81=9F=E3=82=81=E5=89=8A=E9=99=A4=E3=80=81?= =?UTF-8?q?shuffle=E3=81=8C=E6=AF=8E=E5=9B=9E=E8=A1=8C=E3=82=8F=E3=82=8C?= =?UTF-8?q?=E3=82=8B=E4=B8=8D=E5=85=B7=E5=90=88=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 1 - library/config_util.py | 9 --------- library/train_util.py | 3 ++- train_db.py | 1 - train_network.py | 1 - train_textual_inversion.py | 1 - 6 files changed, 2 insertions(+), 14 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 96ec3d96..f387179a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -59,7 +59,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) diff --git a/library/config_util.py b/library/config_util.py index 6817d9a3..b1543f63 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -497,15 +497,6 @@ def load_user_config(file: str) -> dict: return config -def blueprint_args_conflict(args,blueprint:Blueprint): - # train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする - # for b in blueprint.dataset_group.datasets: - # for t in b.subsets: - # if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): - # print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir)) - # # args.persistent_data_loader_workers = False - return - # for config test if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/library/train_util.py b/library/train_util.py index 1ba38adc..af9637de 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -437,8 +437,9 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} def set_current_epoch(self, epoch): + if not self.current_epoch == epoch: + self.shuffle_buckets() self.current_epoch = epoch - self.shuffle_buckets() def set_current_step(self, step): self.current_step = step diff --git a/train_db.py b/train_db.py index 50d50345..3a3d2df8 100644 --- a/train_db.py +++ b/train_db.py @@ -54,7 +54,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) diff --git a/train_network.py b/train_network.py index 79d118d0..a7dbd374 100644 --- a/train_network.py +++ b/train_network.py @@ -94,7 +94,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 4f2e2724..149308b4 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -180,7 +180,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0)