diff --git a/fine_tune.py b/fine_tune.py index ffbbbb09..1de1de32 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -244,6 +244,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 489bdbcc..cc9a1444 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -126,6 +126,7 @@ def main(args): batch_size=args.batch_size, shuffle=False, num_workers=args.max_data_loader_n_workers, + pin_memory=args.pin_memory, collate_fn=collate_fn_remove_corrupted, drop_last=False, ) @@ -187,6 +188,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index edeebadf..c4c61257 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -113,6 +113,7 @@ def main(args): dataset, batch_size=args.batch_size, shuffle=False, + pin_memory=args.pin_memory, num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False, @@ -164,6 +165,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + 
parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") parser.add_argument( "--remove_words", diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 019c737a..ef536db0 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -122,6 +122,7 @@ def main(args): dataset, batch_size=1, shuffle=False, + pin_memory=args.pin_memory, num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False, @@ -223,6 +224,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument( "--max_resolution", type=str, diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 07a6510e..50f16691 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -336,6 +336,7 @@ def main(args): dataset, batch_size=args.batch_size, shuffle=False, + pin_memory=args.pin_memory, num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False, @@ -410,6 +411,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument( "--caption_extention", type=str, diff --git a/flux_train.py b/flux_train.py index 4aa67220..d4363db1 100644 
--- a/flux_train.py +++ b/flux_train.py @@ -398,6 +398,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 01991405..d6694476 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -405,6 +405,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/library/train_util.py b/library/train_util.py index 756d88b1..4258e3b8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -12,9 +12,8 @@ import pathlib import re import shutil import time -import typing from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union -from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration import glob import math import os @@ -209,6 +208,19 @@ class ImageInfo: self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime self.resize_interpolation: Optional[str] = None + @staticmethod + def _pin_tensor(tensor): + return tensor.pin_memory() if tensor is not None else tensor + + def pin_memory(self): + self.latents = self._pin_tensor(self.latents) + self.latents_flipped = self._pin_tensor(self.latents_flipped) + self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1) + self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2) + self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2) + self.alpha_mask = self._pin_tensor(self.alpha_mask) + return self + class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -2181,6 +2193,11 @@ class 
DreamBoothDataset(BaseDataset): self.num_reg_images = num_reg_images + def pin_memory(self): + for key in self.image_data.keys(): + if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory): + self.image_data[key].pin_memory() + class FineTuningDataset(BaseDataset): def __init__( @@ -4005,6 +4022,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" ) @@ -5529,6 +5551,8 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) + dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory) + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -5537,6 +5561,7 @@ kwargs_handlers=kwargs_handlers, dynamo_backend=dynamo_backend, deepspeed_plugin=deepspeed_plugin, + dataloader_config=dataloader_config ) print("accelerator device:", accelerator.device) return accelerator @@ -6700,6 +6725,10 @@ class collator_class: dataset.set_current_step(self.current_step.value) return examples[0] + def pin_memory(self): + if hasattr(self.dataset, 'pin_memory') and callable(self.dataset.pin_memory): + self.dataset.pin_memory() + class LossRecorder: def __init__(self): diff --git a/sd3_train.py b/sd3_train.py
index c6a2fdd8..f90a811e 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -502,6 +502,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train.py b/sdxl_train.py index f454263a..33cf1d28 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -431,6 +431,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 3d107e57..2b814880 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -282,6 +282,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 4dd4b8d9..2f01145e 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -273,6 +273,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 0a9f4a92..4355e3d9 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -221,6 +221,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/tests/test_pin_memory.py b/tests/test_pin_memory.py new file mode 100644 index 00000000..ab5fffac --- /dev/null +++ b/tests/test_pin_memory.py @@ -0,0 +1,175 @@ +import argparse +import pytest +import torch +import importlib.metadata +import multiprocessing + +def test_pin_memory_argument(): + """ + Test that the pin_memory 
argument is correctly added to argument parsers + """ + from library.train_util import add_training_arguments + + parser = argparse.ArgumentParser() + add_training_arguments(parser, support_dreambooth=True) + + # Parse an empty list of arguments to check the default + args = parser.parse_args([]) + assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser" + assert args.pin_memory is False, "pin_memory should default to False" + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_dreambooth_dataset_pin_memory(): + """ + Test pin_memory functionality using a simple mock dataset + """ + from library.train_util import DreamBoothDataset, DreamBoothSubset, collator_class + + # Create a mock DreamBoothSubset with minimal arguments + def create_mock_subset(): + return DreamBoothSubset( + image_dir='/mock/path', + is_reg=False, + class_tokens='test_token', + caption_extension='.txt', + alpha_mask=False, + num_repeats=1, + shuffle_caption=False, + caption_separator=',', + keep_tokens=0, + keep_tokens_separator='', + secondary_separator='', + enable_wildcard=False, + color_aug=False, + flip_aug=False, + face_crop_aug_range=None, + random_crop=False, + caption_dropout_rate=0, + caption_dropout_every_n_epochs=0, + caption_tag_dropout_rate=0, + caption_prefix='', + caption_suffix='', + token_warmup_min=1, + token_warmup_step=0, + cache_info=False + ) + + # Create a simplified mock dataset + class SimpleMockDataset(torch.utils.data.Dataset): + def __init__(self): + self.data = [torch.randn(64, 64) for _ in range(4)] + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + # Create a DataLoader to test pin_memory + dataloader = torch.utils.data.DataLoader( + SimpleMockDataset(), + batch_size=2, + num_workers=0, # Use 0 to avoid multiprocessing overhead + pin_memory=True + ) + + # Verify pin_memory works correctly + for batch in dataloader: + # Pinning only 
works when CUDA is available + assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned" + break + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_pin_memory_cuda_transfer(): + """ + Test pin_memory functionality for CUDA tensor transfer + """ + # Create a simple dataset + class SimpleCUDADataset(torch.utils.data.Dataset): + def __init__(self): + self.data = [torch.randn(64, 64) for _ in range(4)] + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + # Create a DataLoader with pin_memory enabled + dataloader = torch.utils.data.DataLoader( + SimpleCUDADataset(), + batch_size=2, + num_workers=0, # Use 0 to avoid multiprocessing overhead + pin_memory=True + ) + + # Verify CUDA transfer works with pinned memory + for batch in dataloader: + cuda_batch = [tensor.to('cuda', non_blocking=True) for tensor in batch] + assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned" + break + +def test_training_scripts_pin_memory_support(): + """ + Verify that multiple training scripts support pin_memory argument + """ + training_scripts = [ + "fine_tune.py", + "flux_train.py", + "sd3_train.py", + "sdxl_train.py", + "train_network.py", + "train_textual_inversion.py", + "sdxl_train_control_net.py", + "flux_train_control_net.py", + ] + + from library.train_util import add_training_arguments + + for script in training_scripts: + parser = argparse.ArgumentParser() + add_training_arguments(parser, support_dreambooth=True) + + # Parse arguments to check pin_memory + args = parser.parse_args([]) + assert hasattr(args, "pin_memory"), f"{script} should have pin_memory argument" + +def test_accelerator_pin_memory_config(): + """ + Test that the Accelerator is configured with pin_memory option + Checks compatibility and configuration based on Accelerate library version + """ + from library.train_util import prepare_accelerator + + # Check Accelerate 
library version + try: + accelerate_version = importlib.metadata.version("accelerate") + print(f"Accelerate library version: {accelerate_version}") + except importlib.metadata.PackageNotFoundError: + pytest.skip("Accelerate library not installed") + + # Minimal args to pass initial checks + args = argparse.Namespace( + gradient_accumulation_steps=1, + mixed_precision="no", + log_with=None, + kwargs_handlers=[], + deepspeed_plugin=None, + pin_memory=True, + logging_dir=None, + torch_compile=False, + log_prefix=None, + ddp_gradient_as_bucket_view=False, + ddp_static_graph=False, + ddp_timeout=None, + wandb_api_key=None, + dynamo_backend="NO", + deepspeed=False, + ) + + # Prepare accelerator + accelerator = prepare_accelerator(args) + + # Check for dataloader_config + assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled" + assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory" diff --git a/train_db.py b/train_db.py index 4bf3b31c..cb84f897 100644 --- a/train_db.py +++ b/train_db.py @@ -212,6 +212,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_network.py b/train_network.py index 6cebf5fc..0eb51526 100644 --- a/train_network.py +++ b/train_network.py @@ -783,6 +783,7 @@ class NetworkTrainer: shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 8575698d..9c669a53 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -412,6 +412,7 @@ class TextualInversionTrainer: shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git 
a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 77821095..d11fe8e6 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -317,6 +317,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, )