"""Tests for ``pin_memory`` support in ``library.train_util``.

Originally added as ``tests/test_pin_memory.py``.

Every test imports from ``library.train_util`` lazily, inside the test
function, so an import problem surfaces as a failure of that test rather
than as a collection error for the whole module.
"""

import argparse
import importlib.metadata
import multiprocessing

import pytest
import torch


class _MockTensor:
    """Minimal tensor stand-in that records whether ``pin_memory`` ran."""

    def __init__(self):
        self.pinned = False

    def pin_memory(self):
        """Mark this object as pinned and return it."""
        self.pinned = True
        return self


def _create_mock_subset():
    """Build the ``DreamBoothSubset`` shared by the dataset-based tests."""
    from library.train_util import DreamBoothSubset

    return DreamBoothSubset(
        image_dir='/mock/path',
        is_reg=False,
        class_tokens='test_token',
        caption_extension='.txt',
        alpha_mask=False,
        num_repeats=1,
        shuffle_caption=False,
        caption_separator=',',
        keep_tokens=0,
        keep_tokens_separator='',
        secondary_separator='',
        enable_wildcard=False,
        color_aug=False,
        flip_aug=False,
        face_crop_aug_range=None,
        random_crop=False,
        caption_dropout_rate=0,
        caption_dropout_every_n_epochs=0,
        caption_tag_dropout_rate=0,
        caption_prefix='',
        caption_suffix='',
        token_warmup_min=1,
        token_warmup_step=0,
        cache_info=False,
    )


def _mock_dataset_kwargs():
    """Return the constructor arguments used for every mock dataset.

    A fresh subset is created on each call so that datasets built by
    different tests never share a subset instance.
    """
    return dict(
        subsets=[_create_mock_subset()],
        is_training_dataset=True,
        batch_size=1,
        resolution=(512, 512),
        network_multiplier=1.0,
        enable_bucket=False,
        min_bucket_reso=None,
        max_bucket_reso=None,
        bucket_reso_steps=None,
        bucket_no_upscale=False,
        prior_loss_weight=1.0,
        debug_dataset=False,
        validation_split=0.0,
        validation_seed=None,
        resize_interpolation=None,
    )


def test_pin_memory_argument():
    """
    Test that the pin_memory argument is correctly added to argument parsers
    """
    from library.train_util import add_training_arguments

    parser = argparse.ArgumentParser()
    add_training_arguments(parser, support_dreambooth=True)

    # Parse an empty list of arguments to check the default
    args = parser.parse_args([])
    assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser"
    assert args.pin_memory is False, "pin_memory should default to False"


def test_image_info_pin_memory():
    """
    Test pin_memory method in ImageInfo class

    Every tensor-like attribute is replaced with a mock, and the object
    returned by ``ImageInfo.pin_memory()`` is checked to confirm that each
    mock had ``pin_memory`` called on it.
    """
    from library.train_util import ImageInfo

    image_info = ImageInfo(
        image_key='test_key',
        num_repeats=1,
        caption='test caption',
        is_reg=False,
        absolute_path='/test/path',
    )

    # Set mock tensors that can track pinning
    image_info.latents = _MockTensor()
    image_info.text_encoder_outputs1 = _MockTensor()
    image_info.text_encoder_outputs2 = _MockTensor()
    image_info.text_encoder_pool2 = _MockTensor()
    image_info.alpha_mask = _MockTensor()

    # Call pin_memory
    pinned_image_info = image_info.pin_memory()

    # Verify all tensors are pinned
    assert pinned_image_info.latents.pinned, "Latents should be pinned"
    assert pinned_image_info.text_encoder_outputs1.pinned, "Text encoder outputs1 should be pinned"
    assert pinned_image_info.text_encoder_outputs2.pinned, "Text encoder outputs2 should be pinned"
    assert pinned_image_info.text_encoder_pool2.pinned, "Text encoder pool2 should be pinned"
    assert pinned_image_info.alpha_mask.pinned, "Alpha mask should be pinned"


def test_dreambooth_dataset_pin_memory():
    """
    Test pin_memory method in DreamBoothDataset

    ``image_data`` is replaced with mock entries so the test can observe
    that ``pin_memory`` is called on every value.
    """
    from library.train_util import DreamBoothDataset

    class MockDreamBoothDataset(DreamBoothDataset):
        def __init__(self):
            # Call parent constructor with minimal required arguments
            super().__init__(**_mock_dataset_kwargs())

            # Add mock image data for pin_memory testing
            self.image_data = {
                'mock_tensor1': _MockTensor(),
                'mock_tensor2': _MockTensor(),
            }

    dataset = MockDreamBoothDataset()

    # Verify initial state
    for tensor in dataset.image_data.values():
        assert not getattr(tensor, 'pinned', False), "Tensors should not be pinned initially"

    # Call pin_memory
    dataset.pin_memory()

    # Verify all tensors are pinned
    for tensor in dataset.image_data.values():
        assert tensor.pinned, "All tensors in image_data should be pinned"


def test_collator_pin_memory_method():
    """
    Test that collator correctly calls pin_memory on the dataset
    """
    from library.train_util import DreamBoothDataset, collator_class

    class MockPinMemoryDataset(DreamBoothDataset):
        """Dataset whose ``pin_memory`` only records that it was called."""

        def __init__(self):
            # Call parent constructor with minimal required arguments
            super().__init__(**_mock_dataset_kwargs())
            self.pin_memory_called = False

        def pin_memory(self):
            self.pin_memory_called = True
            return self

    # The manager runs a server process; using it as a context manager
    # guarantees that process is shut down even if the test body raises.
    with multiprocessing.Manager() as mp_manager:
        current_epoch = mp_manager.Value('i', 0)
        current_step = mp_manager.Value('i', 0)

        # Create a dataset and collator
        dataset = MockPinMemoryDataset()
        collator = collator_class(current_epoch, current_step, dataset)

        # Call pin_memory on the collator
        collator.pin_memory()

    # Verify pin_memory was called on the dataset
    assert dataset.pin_memory_called, "Collator should call pin_memory on the dataset"


def test_training_scripts_pin_memory_support():
    """
    Verify that the shared training arguments expose pin_memory

    NOTE(review): the scripts listed below are not imported or executed.
    Each iteration only builds a parser through ``add_training_arguments``;
    the script name is used solely in the failure message. Confirm whether
    each script's own parser should be exercised instead.
    """
    training_scripts = [
        "fine_tune.py",
        "flux_train.py",
        "sd3_train.py",
        "sdxl_train.py",
        "train_network.py",
        "train_textual_inversion.py",
        "sdxl_train_control_net.py",
        "flux_train_control_net.py",
    ]

    from library.train_util import add_training_arguments

    for script in training_scripts:
        parser = argparse.ArgumentParser()
        add_training_arguments(parser, support_dreambooth=True)

        # Parse arguments to check pin_memory
        args = parser.parse_args([])
        assert hasattr(args, "pin_memory"), f"{script} should have pin_memory argument"


def test_accelerator_pin_memory_config():
    """
    Test that the Accelerator is configured with pin_memory option

    Builds a minimal argument namespace with ``pin_memory=True``, passes it
    to ``prepare_accelerator`` and checks that the resulting accelerator has
    a ``dataloader_config`` whose ``non_blocking`` flag is ``True``.
    """
    from library.train_util import prepare_accelerator

    # Report the Accelerate library version to help diagnose failures
    try:
        accelerate_version = importlib.metadata.version("accelerate")
    except importlib.metadata.PackageNotFoundError:
        pytest.fail("Accelerate library not installed")
    else:
        print(f"Accelerate library version: {accelerate_version}")

    # Minimal args to pass initial checks
    args = argparse.Namespace(
        gradient_accumulation_steps=1,
        mixed_precision="no",
        log_with=None,
        kwargs_handlers=[],
        deepspeed_plugin=None,
        pin_memory=True,
        logging_dir=None,
        torch_compile=False,
        log_prefix=None,
        ddp_gradient_as_bucket_view=False,
        ddp_static_graph=False,
        ddp_timeout=None,
        wandb_api_key=None,
        dynamo_backend="NO",
        deepspeed=False,
    )

    # Prepare accelerator
    accelerator = prepare_accelerator(args)

    # Check for dataloader_config
    assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled"
    assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with non_blocking=True when pin_memory is enabled"