diff --git a/fine_tune.py b/fine_tune.py
index 96ec3d96..f387179a 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -59,7 +59,6 @@ def train(args):
     }
 
   blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-  config_util.blueprint_args_conflict(args,blueprint)
   train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   current_epoch = Value('i',0)
diff --git a/library/config_util.py b/library/config_util.py
index 6817d9a3..b1543f63 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -497,15 +497,6 @@ def load_user_config(file: str) -> dict:
   return config
 
 
-def blueprint_args_conflict(args,blueprint:Blueprint):
-  # train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする
-  # for b in blueprint.dataset_group.datasets:
-  #   for t in b.subsets:
-  #     if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0):
-  #       print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir))
-  #       # args.persistent_data_loader_workers = False
-  return
-
 # for config test
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
diff --git a/library/train_util.py b/library/train_util.py
index 1ba38adc..af9637de 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -437,8 +437,9 @@ class BaseDataset(torch.utils.data.Dataset):
     self.replacements = {}
 
   def set_current_epoch(self, epoch):
+    if self.current_epoch != epoch:
+      self.shuffle_buckets()
     self.current_epoch = epoch
-    self.shuffle_buckets()
 
   def set_current_step(self, step):
     self.current_step = step
diff --git a/train_db.py b/train_db.py
index 50d50345..3a3d2df8 100644
--- a/train_db.py
+++ b/train_db.py
@@ -54,7 +54,6 @@ def train(args):
     }
 
   blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-  config_util.blueprint_args_conflict(args,blueprint)
   train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   current_epoch = Value('i',0)
diff --git a/train_network.py b/train_network.py
index 79d118d0..a7dbd374 100644
--- a/train_network.py
+++ b/train_network.py
@@ -94,7 +94,6 @@ def train(args):
     }
 
   blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-  config_util.blueprint_args_conflict(args,blueprint)
   train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   current_epoch = Value('i',0)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 4f2e2724..149308b4 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -180,7 +180,6 @@ def train(args):
     }
 
   blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-  config_util.blueprint_args_conflict(args,blueprint)
   train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
   current_epoch = Value('i',0)