Update pin_memory tests to use DataLoader

This commit is contained in:
rockerBOO
2025-06-16 18:08:19 -04:00
parent 95e260fb99
commit 81df559406

View File

@@ -18,54 +18,13 @@ def test_pin_memory_argument():
assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser"
assert args.pin_memory is False, "pin_memory should default to False"
def test_image_info_pin_memory():
    """
    Test pin_memory method in ImageInfo class

    Five attributes of an ImageInfo are replaced with stand-in objects that
    record whether pin_memory() was called on them. The object returned by
    ImageInfo.pin_memory() must then report each of them as pinned.
    """
    from library.train_util import ImageInfo

    class MockTensor:
        """Stand-in tensor that remembers whether pin_memory() was called."""

        def __init__(self):
            self.pinned = False

        def pin_memory(self):
            self.pinned = True
            return self

    # Attribute name -> assertion message, in the order they are checked.
    expected_pinned = {
        "latents": "Latents should be pinned",
        "text_encoder_outputs1": "Text encoder outputs1 should be pinned",
        "text_encoder_outputs2": "Text encoder outputs2 should be pinned",
        "text_encoder_pool2": "Text encoder pool2 should be pinned",
        "alpha_mask": "Alpha mask should be pinned",
    }

    info = ImageInfo(
        image_key='test_key',
        num_repeats=1,
        caption='test caption',
        is_reg=False,
        absolute_path='/test/path',
    )
    for attr_name in expected_pinned:
        setattr(info, attr_name, MockTensor())

    result = info.pin_memory()

    for attr_name, message in expected_pinned.items():
        assert getattr(result, attr_name).pinned, message
def test_dreambooth_dataset_pin_memory():
"""
Test pin_memory method in DreamBoothDataset
Test pin_memory functionality using a simple mock dataset
"""
from library.train_util import DreamBoothDataset, DreamBoothSubset
from library.train_util import DreamBoothDataset, DreamBoothSubset, collator_class
# Create a mock DreamBoothSubset with default arguments
# Create a mock DreamBoothSubset with minimal arguments
def create_mock_subset():
return DreamBoothSubset(
image_dir='/mock/path',
@@ -94,140 +53,59 @@ def test_dreambooth_dataset_pin_memory():
cache_info=False
)
# Create a mock DreamBoothDataset
class MockDreamBoothDataset(DreamBoothDataset):
# Create a simplified mock dataset
class SimpleMockDataset(torch.utils.data.Dataset):
def __init__(self):
# Prepare subset
subsets = [create_mock_subset()]
# Call parent constructor with minimal required arguments
super().__init__(
subsets=subsets,
is_training_dataset=True,
batch_size=1,
resolution=(512, 512),
network_multiplier=1.0,
enable_bucket=False,
min_bucket_reso=None,
max_bucket_reso=None,
bucket_reso_steps=None,
bucket_no_upscale=False,
prior_loss_weight=1.0,
debug_dataset=False,
validation_split=0.0,
validation_seed=None,
resize_interpolation=None
)
# Add mock image data for pin_memory testing
self.image_data = {
'mock_tensor1': self._create_mock_tensor(),
'mock_tensor2': self._create_mock_tensor()
}
self.data = [torch.randn(64, 64) for _ in range(4)]
def _create_mock_tensor(self):
class MockTensor:
def __init__(self):
self.pinned = False
def pin_memory(self):
self.pinned = True
return self
return MockTensor()
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# Create dataset
dataset = MockDreamBoothDataset()
# Create a DataLoader to test pin_memory
dataloader = torch.utils.data.DataLoader(
SimpleMockDataset(),
batch_size=2,
num_workers=0, # Use 0 to avoid multiprocessing overhead
pin_memory=True
)
# Verify initial state
for tensor in dataset.image_data.values():
assert not hasattr(tensor, 'pinned') or not tensor.pinned, "Tensors should not be pinned initially"
# Verify pin_memory works correctly
for batch in dataloader:
assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned"
break
# Call pin_memory
dataset.pin_memory()
# Verify all tensors are pinned
for tensor in dataset.image_data.values():
assert tensor.pinned, "All tensors in image_data should be pinned"
def test_collator_pin_memory_method():
    """
    Test that collator correctly calls pin_memory on the dataset

    A DreamBoothDataset subclass overrides pin_memory() to record the call;
    the test then invokes pin_memory() on a collator built around that
    dataset and checks the flag.
    """
    from library.train_util import collator_class, DreamBoothDataset, DreamBoothSubset

    # Create a mock dataset that tracks pin_memory calls
    class MockPinMemoryDataset(DreamBoothDataset):
        def __init__(self):
            # Prepare subset
            def create_mock_subset():
                return DreamBoothSubset(
                    image_dir='/mock/path',
                    is_reg=False,
                    class_tokens='test_token',
                    caption_extension='.txt',
                    alpha_mask=False,
                    num_repeats=1,
                    shuffle_caption=False,
                    caption_separator=',',
                    keep_tokens=0,
                    keep_tokens_separator='',
                    secondary_separator='',
                    enable_wildcard=False,
                    color_aug=False,
                    flip_aug=False,
                    face_crop_aug_range=None,
                    random_crop=False,
                    caption_dropout_rate=0,
                    caption_dropout_every_n_epochs=0,
                    caption_tag_dropout_rate=0,
                    caption_prefix='',
                    caption_suffix='',
                    token_warmup_min=1,
                    token_warmup_step=0,
                    cache_info=False
                )

            subsets = [create_mock_subset()]

            # Call parent constructor with minimal required arguments
            super().__init__(
                subsets=subsets,
                is_training_dataset=True,
                batch_size=1,
                resolution=(512, 512),
                network_multiplier=1.0,
                enable_bucket=False,
                min_bucket_reso=None,
                max_bucket_reso=None,
                bucket_reso_steps=None,
                bucket_no_upscale=False,
                prior_loss_weight=1.0,
                debug_dataset=False,
                validation_split=0.0,
                validation_seed=None,
                resize_interpolation=None
            )
            self.pin_memory_called = False

        def pin_memory(self):
            self.pin_memory_called = True
            return self

    # Create a multiprocessing manager for current epoch and step
    mp_manager = multiprocessing.Manager()
    current_epoch = mp_manager.Value('i', 0)
    current_step = mp_manager.Value('i', 0)

    # Create a dataset and collator
    dataset = MockPinMemoryDataset()
    collator = collator_class(current_epoch, current_step, dataset)

    # Call pin_memory on the collator
    collator.pin_memory()

    # Verify pin_memory was called on the dataset
    assert dataset.pin_memory_called, "Collator should call pin_memory on the dataset"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pin_memory_cuda_transfer():
    """
    Test pin_memory functionality for CUDA tensor transfer

    A DataLoader built with pin_memory=True is expected to yield pinned host
    tensors. Copying them to the GPU with non_blocking=True must produce
    CUDA tensors that hold the same values as the pinned sources.
    """
    # Create a simple dataset
    class SimpleCUDADataset(torch.utils.data.Dataset):
        def __init__(self):
            self.data = [torch.randn(64, 64) for _ in range(4)]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    # Create a DataLoader with pin_memory enabled
    dataloader = torch.utils.data.DataLoader(
        SimpleCUDADataset(),
        batch_size=2,
        num_workers=0,  # Use 0 to avoid multiprocessing overhead
        pin_memory=True
    )

    # Only the first batch is needed
    batch = next(iter(dataloader))
    assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned"

    # Verify CUDA transfer works with pinned memory
    cuda_batch = [tensor.to('cuda', non_blocking=True) for tensor in batch]
    # The copies were requested as non-blocking; wait for them before inspecting.
    torch.cuda.synchronize()
    assert all(tensor.is_cuda for tensor in cuda_batch), "All tensors should be on the CUDA device"
    assert all(
        torch.equal(moved.cpu(), source) for moved, source in zip(cuda_batch, batch)
    ), "CUDA copies should match the pinned source tensors"
def test_training_scripts_pin_memory_support():
"""
@@ -266,7 +144,7 @@ def test_accelerator_pin_memory_config():
accelerate_version = importlib.metadata.version("accelerate")
print(f"Accelerate library version: {accelerate_version}")
except importlib.metadata.PackageNotFoundError:
pytest.fail("Accelerate library not installed")
pytest.skip("Accelerate library not installed")
# Minimal args to pass initial checks
args = argparse.Namespace(
@@ -292,4 +170,4 @@ def test_accelerator_pin_memory_config():
# Check for dataloader_config
assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled"
assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory"
assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory"