add --persistent_data_loader_workers option

This commit is contained in:
hitomi
2023-02-01 16:02:15 +08:00
parent 4cabb37977
commit 26a81d075c
5 changed files with 9 additions and 7 deletions

View File

@@ -140,7 +140,7 @@ class BaseDataset(torch.utils.data.Dataset):
if type(str_to) == list:
caption = random.choice(str_to)
else:
caption = str_to
caption = str_to
else:
caption = caption.replace(str_from, str_to)
@@ -246,7 +246,7 @@ class BaseDataset(torch.utils.data.Dataset):
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
print(f"mean ar error (without repeats): {mean_img_ar_error}")
# 参照用indexを作る
self.buckets_indices: list(BucketBatchIndex) = []
@@ -1154,6 +1154,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減り、エポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
parser.add_argument("--persistent_data_loader_workers", action="store_true",
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / gradient checkpointingを有効にする")