mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
296 lines
9.9 KiB
Python
296 lines
9.9 KiB
Python
import argparse
|
|
import pytest
|
|
import torch
|
|
import importlib.metadata
|
|
import multiprocessing
|
|
|
|
def test_pin_memory_argument():
|
|
"""
|
|
Test that the pin_memory argument is correctly added to argument parsers
|
|
"""
|
|
from library.train_util import add_training_arguments
|
|
|
|
parser = argparse.ArgumentParser()
|
|
add_training_arguments(parser, support_dreambooth=True)
|
|
|
|
# Parse an empty list of arguments to check the default
|
|
args = parser.parse_args([])
|
|
assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser"
|
|
assert args.pin_memory is False, "pin_memory should default to False"
|
|
|
|
def test_image_info_pin_memory():
|
|
"""
|
|
Test pin_memory method in ImageInfo class
|
|
"""
|
|
from library.train_util import ImageInfo
|
|
|
|
# Create an ImageInfo instance with mock tensors
|
|
image_info = ImageInfo(
|
|
image_key='test_key',
|
|
num_repeats=1,
|
|
caption='test caption',
|
|
is_reg=False,
|
|
absolute_path='/test/path'
|
|
)
|
|
|
|
# Add mock tensors that can track pinning
|
|
class MockTensor:
|
|
def __init__(self):
|
|
self.pinned = False
|
|
|
|
def pin_memory(self):
|
|
self.pinned = True
|
|
return self
|
|
|
|
# Set mock tensors
|
|
image_info.latents = MockTensor()
|
|
image_info.text_encoder_outputs1 = MockTensor()
|
|
image_info.text_encoder_outputs2 = MockTensor()
|
|
image_info.text_encoder_pool2 = MockTensor()
|
|
image_info.alpha_mask = MockTensor()
|
|
|
|
# Call pin_memory
|
|
pinned_image_info = image_info.pin_memory()
|
|
|
|
# Verify all tensors are pinned
|
|
assert pinned_image_info.latents.pinned, "Latents should be pinned"
|
|
assert pinned_image_info.text_encoder_outputs1.pinned, "Text encoder outputs1 should be pinned"
|
|
assert pinned_image_info.text_encoder_outputs2.pinned, "Text encoder outputs2 should be pinned"
|
|
assert pinned_image_info.text_encoder_pool2.pinned, "Text encoder pool2 should be pinned"
|
|
assert pinned_image_info.alpha_mask.pinned, "Alpha mask should be pinned"
|
|
|
|
def test_dreambooth_dataset_pin_memory():
|
|
"""
|
|
Test pin_memory method in DreamBoothDataset
|
|
"""
|
|
from library.train_util import DreamBoothDataset, DreamBoothSubset
|
|
|
|
# Create a mock DreamBoothSubset with default arguments
|
|
def create_mock_subset():
|
|
return DreamBoothSubset(
|
|
image_dir='/mock/path',
|
|
is_reg=False,
|
|
class_tokens='test_token',
|
|
caption_extension='.txt',
|
|
alpha_mask=False,
|
|
num_repeats=1,
|
|
shuffle_caption=False,
|
|
caption_separator=',',
|
|
keep_tokens=0,
|
|
keep_tokens_separator='',
|
|
secondary_separator='',
|
|
enable_wildcard=False,
|
|
color_aug=False,
|
|
flip_aug=False,
|
|
face_crop_aug_range=None,
|
|
random_crop=False,
|
|
caption_dropout_rate=0,
|
|
caption_dropout_every_n_epochs=0,
|
|
caption_tag_dropout_rate=0,
|
|
caption_prefix='',
|
|
caption_suffix='',
|
|
token_warmup_min=1,
|
|
token_warmup_step=0,
|
|
cache_info=False
|
|
)
|
|
|
|
# Create a mock DreamBoothDataset
|
|
class MockDreamBoothDataset(DreamBoothDataset):
|
|
def __init__(self):
|
|
# Prepare subset
|
|
subsets = [create_mock_subset()]
|
|
|
|
# Call parent constructor with minimal required arguments
|
|
super().__init__(
|
|
subsets=subsets,
|
|
is_training_dataset=True,
|
|
batch_size=1,
|
|
resolution=(512, 512),
|
|
network_multiplier=1.0,
|
|
enable_bucket=False,
|
|
min_bucket_reso=None,
|
|
max_bucket_reso=None,
|
|
bucket_reso_steps=None,
|
|
bucket_no_upscale=False,
|
|
prior_loss_weight=1.0,
|
|
debug_dataset=False,
|
|
validation_split=0.0,
|
|
validation_seed=None,
|
|
resize_interpolation=None
|
|
)
|
|
|
|
# Add mock image data for pin_memory testing
|
|
self.image_data = {
|
|
'mock_tensor1': self._create_mock_tensor(),
|
|
'mock_tensor2': self._create_mock_tensor()
|
|
}
|
|
|
|
def _create_mock_tensor(self):
|
|
class MockTensor:
|
|
def __init__(self):
|
|
self.pinned = False
|
|
|
|
def pin_memory(self):
|
|
self.pinned = True
|
|
return self
|
|
return MockTensor()
|
|
|
|
# Create dataset
|
|
dataset = MockDreamBoothDataset()
|
|
|
|
# Verify initial state
|
|
for tensor in dataset.image_data.values():
|
|
assert not hasattr(tensor, 'pinned') or not tensor.pinned, "Tensors should not be pinned initially"
|
|
|
|
# Call pin_memory
|
|
dataset.pin_memory()
|
|
|
|
# Verify all tensors are pinned
|
|
for tensor in dataset.image_data.values():
|
|
assert tensor.pinned, "All tensors in image_data should be pinned"
|
|
|
|
def test_collator_pin_memory_method():
|
|
"""
|
|
Test that collator correctly calls pin_memory on the dataset
|
|
"""
|
|
from library.train_util import collator_class, DreamBoothDataset, DreamBoothSubset
|
|
|
|
# Create a mock dataset that tracks pin_memory calls
|
|
class MockPinMemoryDataset(DreamBoothDataset):
|
|
def __init__(self):
|
|
# Prepare subset
|
|
def create_mock_subset():
|
|
return DreamBoothSubset(
|
|
image_dir='/mock/path',
|
|
is_reg=False,
|
|
class_tokens='test_token',
|
|
caption_extension='.txt',
|
|
alpha_mask=False,
|
|
num_repeats=1,
|
|
shuffle_caption=False,
|
|
caption_separator=',',
|
|
keep_tokens=0,
|
|
keep_tokens_separator='',
|
|
secondary_separator='',
|
|
enable_wildcard=False,
|
|
color_aug=False,
|
|
flip_aug=False,
|
|
face_crop_aug_range=None,
|
|
random_crop=False,
|
|
caption_dropout_rate=0,
|
|
caption_dropout_every_n_epochs=0,
|
|
caption_tag_dropout_rate=0,
|
|
caption_prefix='',
|
|
caption_suffix='',
|
|
token_warmup_min=1,
|
|
token_warmup_step=0,
|
|
cache_info=False
|
|
)
|
|
|
|
# Prepare subset
|
|
subsets = [create_mock_subset()]
|
|
|
|
# Call parent constructor with minimal required arguments
|
|
super().__init__(
|
|
subsets=subsets,
|
|
is_training_dataset=True,
|
|
batch_size=1,
|
|
resolution=(512, 512),
|
|
network_multiplier=1.0,
|
|
enable_bucket=False,
|
|
min_bucket_reso=None,
|
|
max_bucket_reso=None,
|
|
bucket_reso_steps=None,
|
|
bucket_no_upscale=False,
|
|
prior_loss_weight=1.0,
|
|
debug_dataset=False,
|
|
validation_split=0.0,
|
|
validation_seed=None,
|
|
resize_interpolation=None
|
|
)
|
|
self.pin_memory_called = False
|
|
|
|
def pin_memory(self):
|
|
self.pin_memory_called = True
|
|
return self
|
|
|
|
# Create a multiprocessing manager for current epoch and step
|
|
mp_manager = multiprocessing.Manager()
|
|
current_epoch = mp_manager.Value('i', 0)
|
|
current_step = mp_manager.Value('i', 0)
|
|
|
|
# Create a dataset and collator
|
|
dataset = MockPinMemoryDataset()
|
|
collator = collator_class(current_epoch, current_step, dataset)
|
|
|
|
# Call pin_memory on the collator
|
|
collator.pin_memory()
|
|
|
|
# Verify pin_memory was called on the dataset
|
|
assert dataset.pin_memory_called, "Collator should call pin_memory on the dataset"
|
|
|
|
def test_training_scripts_pin_memory_support():
|
|
"""
|
|
Verify that multiple training scripts support pin_memory argument
|
|
"""
|
|
training_scripts = [
|
|
"fine_tune.py",
|
|
"flux_train.py",
|
|
"sd3_train.py",
|
|
"sdxl_train.py",
|
|
"train_network.py",
|
|
"train_textual_inversion.py",
|
|
"sdxl_train_control_net.py",
|
|
"flux_train_control_net.py",
|
|
]
|
|
|
|
from library.train_util import add_training_arguments
|
|
|
|
for script in training_scripts:
|
|
parser = argparse.ArgumentParser()
|
|
add_training_arguments(parser, support_dreambooth=True)
|
|
|
|
# Parse arguments to check pin_memory
|
|
args = parser.parse_args([])
|
|
assert hasattr(args, "pin_memory"), f"{script} should have pin_memory argument"
|
|
|
|
def test_accelerator_pin_memory_config():
|
|
"""
|
|
Test that the Accelerator is configured with pin_memory option
|
|
Checks compatibility and configuration based on Accelerate library version
|
|
"""
|
|
from library.train_util import prepare_accelerator
|
|
|
|
# Check Accelerate library version
|
|
try:
|
|
accelerate_version = importlib.metadata.version("accelerate")
|
|
print(f"Accelerate library version: {accelerate_version}")
|
|
except importlib.metadata.PackageNotFoundError:
|
|
pytest.fail("Accelerate library not installed")
|
|
|
|
# Minimal args to pass initial checks
|
|
args = argparse.Namespace(
|
|
gradient_accumulation_steps=1,
|
|
mixed_precision="no",
|
|
log_with=None,
|
|
kwargs_handlers=[],
|
|
deepspeed_plugin=None,
|
|
pin_memory=True,
|
|
logging_dir=None,
|
|
torch_compile=False,
|
|
log_prefix=None,
|
|
ddp_gradient_as_bucket_view=False,
|
|
ddp_static_graph=False,
|
|
ddp_timeout=None,
|
|
wandb_api_key=None,
|
|
dynamo_backend="NO",
|
|
deepspeed=False,
|
|
)
|
|
|
|
# Prepare accelerator
|
|
accelerator = prepare_accelerator(args)
|
|
|
|
# Check for dataloader_config
|
|
assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled"
|
|
assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory"
|