diff --git a/fine_tune.py b/fine_tune.py index 17608706..7fbc5877 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -242,6 +242,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 489bdbcc..ded9a747 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -126,6 +126,7 @@ def main(args): batch_size=args.batch_size, shuffle=False, num_workers=args.max_data_loader_n_workers, + pin_memory=args.pin_memory, collate_fn=collate_fn_remove_corrupted, drop_last=False, ) diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index edeebadf..babdaea5 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -113,6 +113,7 @@ def main(args): dataset, batch_size=args.batch_size, shuffle=False, + pin_memory=args.pin_memory, num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False, diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 019c737a..77b82971 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -122,6 +122,7 @@ def main(args): dataset, batch_size=1, shuffle=False, + pin_memory=args.pin_memory, num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False, diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index cbc3d2d6..c786e8a6 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -335,6 +335,7 @@ def main(args): dataset, batch_size=args.batch_size, shuffle=False, + pin_memory=args.pin_memory, num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False, diff --git a/flux_train.py b/flux_train.py index fced3bef..4e1b0b4a 100644 --- a/flux_train.py +++ b/flux_train.py @@ -397,6 +397,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 9d36a41d..6a515154 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -398,6 +398,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/library/train_util.py b/library/train_util.py index 72b5b24d..1e6fe3b8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -176,6 +176,19 @@ class ImageInfo: self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + @staticmethod + def _pin_tensor(tensor): + return tensor.pin_memory() if tensor is not None else tensor + + def pin_memory(self): + self.latents = self._pin_tensor(self.latents) + self.latents_flipped = self._pin_tensor(self.latents_flipped) + self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1) + self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2) + self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2) + self.alpha_mask = self._pin_tensor(self.alpha_mask) + return self + class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -2036,6 +2049,11 @@ class DreamBoothDataset(BaseDataset): self.num_reg_images = num_reg_images + def pin_memory(self): + for key in self.image_data.keys(): + if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory): + self.image_data[key].pin_memory() + class FineTuningDataset(BaseDataset): def __init__( @@ -3734,6 +3752,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" @@ -6379,6 +6402,10 @@ class collator_class: dataset.set_current_step(self.current_step.value) return examples[0] + def pin_memory(self): + if hasattr(self, 'pin_memory') and callable(self.pin_memory): + self.dataset.pin_memory() + class LossRecorder: def __init__(self): diff --git a/sd3_train.py b/sd3_train.py index 120455e7..116e4988 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -498,6 +498,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train.py b/sdxl_train.py index b9d52924..2b60ebba 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -430,6 +430,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index ffbf03ca..32c9996a 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -281,6 +281,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 365059b7..d74ed99f 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -272,6 +272,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 5b372bef..098f7f56 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -220,6 +220,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_db.py b/train_db.py index ad21f8d1..1b5ec198 100644 --- a/train_db.py +++ b/train_db.py @@ -210,6 +210,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_network.py b/train_network.py index 5e82b307..7e1665d5 100644 --- a/train_network.py +++ b/train_network.py @@ -577,6 +577,7 @@ class NetworkTrainer: shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 65da4859..14a548a0 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -408,6 +408,7 @@ class TextualInversionTrainer: shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2a2b4231..f63dac86 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -316,6 +316,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, )