From c4b0bb6fce70c12fd63d706154783d56fe3ed9ab Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 23 Jan 2025 10:39:01 -0500 Subject: [PATCH 1/6] Add pin_memory to DataLoader and update ImageInfo to support --- fine_tune.py | 1 + finetune/make_captions.py | 1 + finetune/make_captions_by_git.py | 1 + finetune/prepare_buckets_latents.py | 1 + finetune/tag_images_by_wd14_tagger.py | 1 + flux_train.py | 1 + flux_train_control_net.py | 1 + library/train_util.py | 27 +++++++++++++++++++++++++++ sd3_train.py | 1 + sdxl_train.py | 1 + sdxl_train_control_net.py | 1 + sdxl_train_control_net_lllite.py | 1 + sdxl_train_control_net_lllite_old.py | 1 + train_db.py | 1 + train_network.py | 1 + train_textual_inversion.py | 1 + train_textual_inversion_XTI.py | 1 + 17 files changed, 43 insertions(+) diff --git a/fine_tune.py b/fine_tune.py index 17608706..7fbc5877 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -242,6 +242,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 489bdbcc..ded9a747 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -126,6 +126,7 @@ def main(args): batch_size=args.batch_size, shuffle=False, num_workers=args.max_data_loader_n_workers, + pin_memory=args.pin_memory, collate_fn=collate_fn_remove_corrupted, drop_last=False, ) diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index edeebadf..babdaea5 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -113,6 +113,7 @@ def main(args): dataset, batch_size=args.batch_size, shuffle=False, + pin_memory=args.pin_memory, num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False, diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 019c737a..77b82971 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -122,6 +122,7 @@ def main(args): dataset, batch_size=1, shuffle=False, + pin_memory=args.pin_memory, num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False, diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index cbc3d2d6..c786e8a6 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -335,6 +335,7 @@ def main(args): dataset, batch_size=args.batch_size, shuffle=False, + pin_memory=args.pin_memory, num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False, diff --git a/flux_train.py b/flux_train.py index fced3bef..4e1b0b4a 100644 --- a/flux_train.py +++ b/flux_train.py @@ -397,6 +397,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 9d36a41d..6a515154 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -398,6 +398,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/library/train_util.py b/library/train_util.py index 72b5b24d..1e6fe3b8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -176,6 +176,19 @@ class ImageInfo: self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + @staticmethod + def _pin_tensor(tensor): + return tensor.pin_memory() if tensor is not None else tensor + + def pin_memory(self): + self.latents = self._pin_tensor(self.latents) + self.latents_flipped = self._pin_tensor(self.latents_flipped) + self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1) + self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2) + self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2) + self.alpha_mask = self._pin_tensor(self.alpha_mask) + return self + class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -2036,6 +2049,11 @@ class DreamBoothDataset(BaseDataset): self.num_reg_images = num_reg_images + def pin_memory(self): + for key in self.image_data.keys(): + if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory): + self.image_data[key].pin_memory() + class FineTuningDataset(BaseDataset): def __init__( @@ -3734,6 +3752,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" @@ -6379,6 +6402,10 @@ class collator_class: dataset.set_current_step(self.current_step.value) return examples[0] + def pin_memory(self): + if hasattr(self, 'pin_memory') and callable(self.pin_memory): + self.dataset.pin_memory() + class LossRecorder: def __init__(self): diff --git a/sd3_train.py b/sd3_train.py index 120455e7..116e4988 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -498,6 +498,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train.py b/sdxl_train.py index b9d52924..2b60ebba 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -430,6 +430,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index ffbf03ca..32c9996a 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -281,6 +281,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 365059b7..d74ed99f 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -272,6 +272,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 5b372bef..098f7f56 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -220,6 +220,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_db.py b/train_db.py index ad21f8d1..1b5ec198 100644 --- a/train_db.py +++ b/train_db.py @@ -210,6 +210,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_network.py b/train_network.py index 5e82b307..7e1665d5 100644 --- a/train_network.py +++ b/train_network.py @@ -577,6 +577,7 @@ class NetworkTrainer: shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 65da4859..14a548a0 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -408,6 +408,7 @@ class TextualInversionTrainer: shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 2a2b4231..f63dac86 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -316,6 +316,7 @@ def train(args): shuffle=True, collate_fn=collator, num_workers=n_workers, + pin_memory=args.pin_memory, persistent_workers=args.persistent_data_loader_workers, ) From 50d8daa7d8711b6a181909246aaca88b8411e080 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 23 Jan 2025 11:02:29 -0500 Subject: [PATCH 2/6] Accelerate dataloader_config to non_blocking if pin_memory is enabled --- library/train_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 1e6fe3b8..9711dd56 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -23,7 +23,7 @@ from typing import ( Tuple, Union ) -from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration import glob import math import os @@ -5299,6 +5299,8 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) + dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory) + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -5307,6 +5309,7 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers=kwargs_handlers, dynamo_backend=dynamo_backend, deepspeed_plugin=deepspeed_plugin, + dataloader_config=dataloader_config ) print("accelerator device:", accelerator.device) return accelerator From 03b35be3876eb8eece1858be5b855fefcec4179d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 23 Jan 2025 12:45:37 -0500 Subject: [PATCH 3/6] Add pin_memory to finetune scripts --- finetune/make_captions.py | 5 +++++ finetune/make_captions_by_git.py | 5 +++++ finetune/prepare_buckets_latents.py | 5 +++++ finetune/tag_images_by_wd14_tagger.py | 5 +++++ 4 files changed, 20 insertions(+) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index ded9a747..cc9a1444 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -188,6 +188,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)") parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index babdaea5..c4c61257 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -165,6 +165,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長") parser.add_argument( "--remove_words", diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 77b82971..ef536db0 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -224,6 +224,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument( "--max_resolution", type=str, diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index c786e8a6..6ed595de 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -410,6 +410,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)", ) + parser.add_argument( + "--pin_memory", + action="store_true", + help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ", + ) parser.add_argument( "--caption_extention", type=str, From 95e260fb99f24f5ba8aa7f6a4ac55bb9899e5102 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 16 Jun 2025 17:53:02 -0400 Subject: [PATCH 4/6] Add tests for pin memory --- tests/test_pin_memory.py | 295 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 295 insertions(+) create mode 100644 tests/test_pin_memory.py diff --git a/tests/test_pin_memory.py b/tests/test_pin_memory.py new file mode 100644 index 00000000..f54e2fcb --- /dev/null +++ b/tests/test_pin_memory.py @@ -0,0 +1,295 @@ +import argparse +import pytest +import torch +import importlib.metadata +import multiprocessing + +def test_pin_memory_argument(): + """ + Test that the pin_memory argument is correctly added to argument parsers + """ + from library.train_util import add_training_arguments + + parser = argparse.ArgumentParser() + add_training_arguments(parser, support_dreambooth=True) + + # Parse an empty list of arguments to check the default + args = parser.parse_args([]) + assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser" + assert args.pin_memory is False, "pin_memory should default to False" + +def test_image_info_pin_memory(): + """ + Test pin_memory method in ImageInfo class + """ + from library.train_util import ImageInfo + + # Create an ImageInfo instance with mock tensors + image_info = ImageInfo( + image_key='test_key', + num_repeats=1, + caption='test caption', + is_reg=False, + absolute_path='/test/path' + ) + + # Add mock tensors that can track pinning + class MockTensor: + def __init__(self): + self.pinned = False + + def pin_memory(self): + self.pinned = True + return self + + # Set mock tensors + image_info.latents = MockTensor() + image_info.text_encoder_outputs1 = MockTensor() + image_info.text_encoder_outputs2 = MockTensor() + image_info.text_encoder_pool2 = MockTensor() + image_info.alpha_mask = MockTensor() + + # Call pin_memory + pinned_image_info = image_info.pin_memory() + + # Verify all tensors are pinned + assert pinned_image_info.latents.pinned, "Latents should be pinned" + assert pinned_image_info.text_encoder_outputs1.pinned, "Text encoder outputs1 should be pinned" + assert pinned_image_info.text_encoder_outputs2.pinned, "Text encoder outputs2 should be pinned" + assert pinned_image_info.text_encoder_pool2.pinned, "Text encoder pool2 should be pinned" + assert pinned_image_info.alpha_mask.pinned, "Alpha mask should be pinned" + +def test_dreambooth_dataset_pin_memory(): + """ + Test pin_memory method in DreamBoothDataset + """ + from library.train_util import DreamBoothDataset, DreamBoothSubset + + # Create a mock DreamBoothSubset with default arguments + def create_mock_subset(): + return DreamBoothSubset( + image_dir='/mock/path', + is_reg=False, + class_tokens='test_token', + caption_extension='.txt', + alpha_mask=False, + num_repeats=1, + shuffle_caption=False, + caption_separator=',', + keep_tokens=0, + keep_tokens_separator='', + secondary_separator='', + enable_wildcard=False, + color_aug=False, + flip_aug=False, + face_crop_aug_range=None, + random_crop=False, + caption_dropout_rate=0, + caption_dropout_every_n_epochs=0, + caption_tag_dropout_rate=0, + caption_prefix='', + caption_suffix='', + token_warmup_min=1, + token_warmup_step=0, + cache_info=False + ) + + # Create a mock DreamBoothDataset + class MockDreamBoothDataset(DreamBoothDataset): + def __init__(self): + # Prepare subset + subsets = [create_mock_subset()] + + # Call parent constructor with minimal required arguments + super().__init__( + subsets=subsets, + is_training_dataset=True, + batch_size=1, + resolution=(512, 512), + network_multiplier=1.0, + enable_bucket=False, + min_bucket_reso=None, + max_bucket_reso=None, + bucket_reso_steps=None, + bucket_no_upscale=False, + prior_loss_weight=1.0, + debug_dataset=False, + validation_split=0.0, + validation_seed=None, + resize_interpolation=None + ) + + # Add mock image data for pin_memory testing + self.image_data = { + 'mock_tensor1': self._create_mock_tensor(), + 'mock_tensor2': self._create_mock_tensor() + } + + def _create_mock_tensor(self): + class MockTensor: + def __init__(self): + self.pinned = False + + def pin_memory(self): + self.pinned = True + return self + return MockTensor() + + # Create dataset + dataset = MockDreamBoothDataset() + + # Verify initial state + for tensor in dataset.image_data.values(): + assert not hasattr(tensor, 'pinned') or not tensor.pinned, "Tensors should not be pinned initially" + + # Call pin_memory + dataset.pin_memory() + + # Verify all tensors are pinned + for tensor in dataset.image_data.values(): + assert tensor.pinned, "All tensors in image_data should be pinned" + +def test_collator_pin_memory_method(): + """ + Test that collator correctly calls pin_memory on the dataset + """ + from library.train_util import collator_class, DreamBoothDataset, DreamBoothSubset + + # Create a mock dataset that tracks pin_memory calls + class MockPinMemoryDataset(DreamBoothDataset): + def __init__(self): + # Prepare subset + def create_mock_subset(): + return DreamBoothSubset( + image_dir='/mock/path', + is_reg=False, + class_tokens='test_token', + caption_extension='.txt', + alpha_mask=False, + num_repeats=1, + shuffle_caption=False, + caption_separator=',', + keep_tokens=0, + keep_tokens_separator='', + secondary_separator='', + enable_wildcard=False, + color_aug=False, + flip_aug=False, + face_crop_aug_range=None, + random_crop=False, + caption_dropout_rate=0, + caption_dropout_every_n_epochs=0, + caption_tag_dropout_rate=0, + caption_prefix='', + caption_suffix='', + token_warmup_min=1, + token_warmup_step=0, + cache_info=False + ) + + # Prepare subset + subsets = [create_mock_subset()] + + # Call parent constructor with minimal required arguments + super().__init__( + subsets=subsets, + is_training_dataset=True, + batch_size=1, + resolution=(512, 512), + network_multiplier=1.0, + enable_bucket=False, + min_bucket_reso=None, + max_bucket_reso=None, + bucket_reso_steps=None, + bucket_no_upscale=False, + prior_loss_weight=1.0, + debug_dataset=False, + validation_split=0.0, + validation_seed=None, + resize_interpolation=None + ) + self.pin_memory_called = False + + def pin_memory(self): + self.pin_memory_called = True + return self + + # Create a multiprocessing manager for current epoch and step + mp_manager = multiprocessing.Manager() + current_epoch = mp_manager.Value('i', 0) + current_step = mp_manager.Value('i', 0) + + # Create a dataset and collator + dataset = MockPinMemoryDataset() + collator = collator_class(current_epoch, current_step, dataset) + + # Call pin_memory on the collator + collator.pin_memory() + + # Verify pin_memory was called on the dataset + assert dataset.pin_memory_called, "Collator should call pin_memory on the dataset" + +def test_training_scripts_pin_memory_support(): + """ + Verify that multiple training scripts support pin_memory argument + """ + training_scripts = [ + "fine_tune.py", + "flux_train.py", + "sd3_train.py", + "sdxl_train.py", + "train_network.py", + "train_textual_inversion.py", + "sdxl_train_control_net.py", + "flux_train_control_net.py", + ] + + from library.train_util import add_training_arguments + + for script in training_scripts: + parser = argparse.ArgumentParser() + add_training_arguments(parser, support_dreambooth=True) + + # Parse arguments to check pin_memory + args = parser.parse_args([]) + assert hasattr(args, "pin_memory"), f"{script} should have pin_memory argument" + +def test_accelerator_pin_memory_config(): + """ + Test that the Accelerator is configured with pin_memory option + Checks compatibility and configuration based on Accelerate library version + """ + from library.train_util import prepare_accelerator + + # Check Accelerate library version + try: + accelerate_version = importlib.metadata.version("accelerate") + print(f"Accelerate library version: {accelerate_version}") + except importlib.metadata.PackageNotFoundError: + pytest.fail("Accelerate library not installed") + + # Minimal args to pass initial checks + args = argparse.Namespace( + gradient_accumulation_steps=1, + mixed_precision="no", + log_with=None, + kwargs_handlers=[], + deepspeed_plugin=None, + pin_memory=True, + logging_dir=None, + torch_compile=False, + log_prefix=None, + ddp_gradient_as_bucket_view=False, + ddp_static_graph=False, + ddp_timeout=None, + wandb_api_key=None, + dynamo_backend="NO", + deepspeed=False, + ) + + # Prepare accelerator + accelerator = prepare_accelerator(args) + + # Check for dataloader_config + assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled" + assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory" From 81df5594067384a5d3ca5375a55596ba1e7e2ca9 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 16 Jun 2025 18:08:19 -0400 Subject: [PATCH 5/6] Update pin_memory tests to use DataLoader --- tests/test_pin_memory.py | 218 +++++++++------------------------------ 1 file changed, 48 insertions(+), 170 deletions(-) diff --git a/tests/test_pin_memory.py b/tests/test_pin_memory.py index f54e2fcb..b52d30b4 100644 --- a/tests/test_pin_memory.py +++ b/tests/test_pin_memory.py @@ -18,54 +18,13 @@ def test_pin_memory_argument(): assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser" assert args.pin_memory is False, "pin_memory should default to False" -def test_image_info_pin_memory(): - """ - Test pin_memory method in ImageInfo class - """ - from library.train_util import ImageInfo - - # Create an ImageInfo instance with mock tensors - image_info = ImageInfo( - image_key='test_key', - num_repeats=1, - caption='test caption', - is_reg=False, - absolute_path='/test/path' - ) - - # Add mock tensors that can track pinning - class MockTensor: - def __init__(self): - self.pinned = False - - def pin_memory(self): - self.pinned = True - return self - - # Set mock tensors - image_info.latents = MockTensor() - image_info.text_encoder_outputs1 = MockTensor() - image_info.text_encoder_outputs2 = MockTensor() - image_info.text_encoder_pool2 = MockTensor() - image_info.alpha_mask = MockTensor() - - # Call pin_memory - pinned_image_info = image_info.pin_memory() - - # Verify all tensors are pinned - assert pinned_image_info.latents.pinned, "Latents should be pinned" - assert pinned_image_info.text_encoder_outputs1.pinned, "Text encoder outputs1 should be pinned" - assert pinned_image_info.text_encoder_outputs2.pinned, "Text encoder outputs2 should be pinned" - assert pinned_image_info.text_encoder_pool2.pinned, "Text encoder pool2 should be pinned" - assert pinned_image_info.alpha_mask.pinned, "Alpha mask should be pinned" - def test_dreambooth_dataset_pin_memory(): """ - Test pin_memory method in DreamBoothDataset + Test pin_memory functionality using a simple mock dataset """ - from library.train_util import DreamBoothDataset, DreamBoothSubset + from library.train_util import DreamBoothDataset, DreamBoothSubset, collator_class - # Create a mock DreamBoothSubset with default arguments + # Create a mock DreamBoothSubset with minimal arguments def create_mock_subset(): return DreamBoothSubset( image_dir='/mock/path', @@ -94,140 +53,59 @@ def test_dreambooth_dataset_pin_memory(): cache_info=False ) - # Create a mock DreamBoothDataset - class MockDreamBoothDataset(DreamBoothDataset): + # Create a simplified mock dataset + class SimpleMockDataset(torch.utils.data.Dataset): def __init__(self): - # Prepare subset - subsets = [create_mock_subset()] - - # Call parent constructor with minimal required arguments - super().__init__( - subsets=subsets, - is_training_dataset=True, - batch_size=1, - resolution=(512, 512), - network_multiplier=1.0, - enable_bucket=False, - min_bucket_reso=None, - max_bucket_reso=None, - bucket_reso_steps=None, - bucket_no_upscale=False, - prior_loss_weight=1.0, - debug_dataset=False, - validation_split=0.0, - validation_seed=None, - resize_interpolation=None - ) - - # Add mock image data for pin_memory testing - self.image_data = { - 'mock_tensor1': self._create_mock_tensor(), - 'mock_tensor2': self._create_mock_tensor() - } + self.data = [torch.randn(64, 64) for _ in range(4)] - def _create_mock_tensor(self): - class MockTensor: - def __init__(self): - self.pinned = False - - def pin_memory(self): - self.pinned = True - return self - return MockTensor() + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] - # Create dataset - dataset = MockDreamBoothDataset() + # Create a DataLoader to test pin_memory + dataloader = torch.utils.data.DataLoader( + SimpleMockDataset(), + batch_size=2, + num_workers=0, # Use 0 to avoid multiprocessing overhead + pin_memory=True + ) - # Verify initial state - for tensor in dataset.image_data.values(): - assert not hasattr(tensor, 'pinned') or not tensor.pinned, "Tensors should not be pinned initially" + # Verify pin_memory works correctly + for batch in dataloader: + assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned" + break - # Call pin_memory - dataset.pin_memory() - - # Verify all tensors are pinned - for tensor in dataset.image_data.values(): - assert tensor.pinned, "All tensors in image_data should be pinned" - -def test_collator_pin_memory_method(): +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_pin_memory_cuda_transfer(): """ - Test that collator correctly calls pin_memory on the dataset + Test pin_memory functionality for CUDA tensor transfer """ - from library.train_util import collator_class, DreamBoothDataset, DreamBoothSubset - - # Create a mock dataset that tracks pin_memory calls - class MockPinMemoryDataset(DreamBoothDataset): + # Create a simple dataset + class SimpleCUDADataset(torch.utils.data.Dataset): def __init__(self): - # Prepare subset - def create_mock_subset(): - return DreamBoothSubset( - image_dir='/mock/path', - is_reg=False, - class_tokens='test_token', - caption_extension='.txt', - alpha_mask=False, - num_repeats=1, - shuffle_caption=False, - caption_separator=',', - keep_tokens=0, - keep_tokens_separator='', - secondary_separator='', - enable_wildcard=False, - color_aug=False, - flip_aug=False, - face_crop_aug_range=None, - random_crop=False, - caption_dropout_rate=0, - caption_dropout_every_n_epochs=0, - caption_tag_dropout_rate=0, - caption_prefix='', - caption_suffix='', - token_warmup_min=1, - token_warmup_step=0, - cache_info=False - ) + self.data = [torch.randn(64, 64) for _ in range(4)] + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] - # Prepare subset - subsets = [create_mock_subset()] + # Create a DataLoader with pin_memory enabled + dataloader = torch.utils.data.DataLoader( + SimpleCUDADataset(), + batch_size=2, + num_workers=0, # Use 0 to avoid multiprocessing overhead + pin_memory=True + ) - # Call parent constructor with minimal required arguments - super().__init__( - subsets=subsets, - is_training_dataset=True, - batch_size=1, - resolution=(512, 512), - network_multiplier=1.0, - enable_bucket=False, - min_bucket_reso=None, - max_bucket_reso=None, - bucket_reso_steps=None, - bucket_no_upscale=False, - prior_loss_weight=1.0, - debug_dataset=False, - validation_split=0.0, - validation_seed=None, - resize_interpolation=None - ) - self.pin_memory_called = False - - def pin_memory(self): - self.pin_memory_called = True - return self - - # Create a multiprocessing manager for current epoch and step - mp_manager = multiprocessing.Manager() - current_epoch = mp_manager.Value('i', 0) - current_step = mp_manager.Value('i', 0) - - # Create a dataset and collator - dataset = MockPinMemoryDataset() - collator = collator_class(current_epoch, current_step, dataset) - - # Call pin_memory on the collator - collator.pin_memory() - - # Verify pin_memory was called on the dataset - assert dataset.pin_memory_called, "Collator should call pin_memory on the dataset" + # Verify CUDA transfer works with pinned memory + for batch in dataloader: + cuda_batch = [tensor.to('cuda', non_blocking=True) for tensor in batch] + assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned" + break def test_training_scripts_pin_memory_support(): """ @@ -266,7 +144,7 @@ def test_accelerator_pin_memory_config(): accelerate_version = importlib.metadata.version("accelerate") print(f"Accelerate library version: {accelerate_version}") except importlib.metadata.PackageNotFoundError: - pytest.fail("Accelerate library not installed") + pytest.skip("Accelerate library not installed") # Minimal args to pass initial checks args = argparse.Namespace( @@ -292,4 +170,4 @@ def test_accelerator_pin_memory_config(): # Check for dataloader_config assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled" - assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory" + assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory" \ No newline at end of file From d73f2db22c686f7ab8ca73c92617fe417b7f23c3 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 16 Jun 2025 18:31:46 -0400 Subject: [PATCH 6/6] Pinning only supported when CUDA is available --- tests/test_pin_memory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_pin_memory.py b/tests/test_pin_memory.py index b52d30b4..ab5fffac 100644 --- a/tests/test_pin_memory.py +++ b/tests/test_pin_memory.py @@ -18,6 +18,7 @@ def test_pin_memory_argument(): assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser" assert args.pin_memory is False, "pin_memory should default to False" +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dreambooth_dataset_pin_memory(): """ Test pin_memory functionality using a simple mock dataset @@ -74,6 +75,7 @@ def test_dreambooth_dataset_pin_memory(): # Verify pin_memory works correctly for batch in dataloader: + # Pinning only works when CUDA is available assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned" break @@ -170,4 +172,4 @@ def test_accelerator_pin_memory_config(): # Check for dataloader_config assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled" - assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory" \ No newline at end of file + assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory"