diff --git a/fine_tune.py b/fine_tune.py index 8e615203..a0ef978e 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -163,7 +163,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: diff --git a/library/train_util.py b/library/train_util.py index c1e54517..aea1bfd1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1192,6 +1192,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") + parser.add_argument("--persistent_data_loader_workers", action="store_true", + help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする") diff --git a/train_db.py b/train_db.py index fe6fd4e6..bf25aae4 100644 --- a/train_db.py +++ b/train_db.py @@ -133,7 +133,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: diff --git a/train_network.py b/train_network.py index 88405221..6dd1a732 100644 --- a/train_network.py +++ b/train_network.py @@ -214,7 +214,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 35b4ede6..ea70195b 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -217,7 +217,7 @@ def train(args): # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) # 学習ステップ数を計算する if args.max_train_epochs is not None: