mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
173 lines
5.6 KiB
Python
173 lines
5.6 KiB
Python
import argparse
|
|
import pytest
|
|
import torch
|
|
import importlib.metadata
|
|
import multiprocessing
|
|
|
|
def test_pin_memory_argument():
|
|
"""
|
|
Test that the pin_memory argument is correctly added to argument parsers
|
|
"""
|
|
from library.train_util import add_training_arguments
|
|
|
|
parser = argparse.ArgumentParser()
|
|
add_training_arguments(parser, support_dreambooth=True)
|
|
|
|
# Parse an empty list of arguments to check the default
|
|
args = parser.parse_args([])
|
|
assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser"
|
|
assert args.pin_memory is False, "pin_memory should default to False"
|
|
|
|
def test_dreambooth_dataset_pin_memory():
|
|
"""
|
|
Test pin_memory functionality using a simple mock dataset
|
|
"""
|
|
from library.train_util import DreamBoothDataset, DreamBoothSubset, collator_class
|
|
|
|
# Create a mock DreamBoothSubset with minimal arguments
|
|
def create_mock_subset():
|
|
return DreamBoothSubset(
|
|
image_dir='/mock/path',
|
|
is_reg=False,
|
|
class_tokens='test_token',
|
|
caption_extension='.txt',
|
|
alpha_mask=False,
|
|
num_repeats=1,
|
|
shuffle_caption=False,
|
|
caption_separator=',',
|
|
keep_tokens=0,
|
|
keep_tokens_separator='',
|
|
secondary_separator='',
|
|
enable_wildcard=False,
|
|
color_aug=False,
|
|
flip_aug=False,
|
|
face_crop_aug_range=None,
|
|
random_crop=False,
|
|
caption_dropout_rate=0,
|
|
caption_dropout_every_n_epochs=0,
|
|
caption_tag_dropout_rate=0,
|
|
caption_prefix='',
|
|
caption_suffix='',
|
|
token_warmup_min=1,
|
|
token_warmup_step=0,
|
|
cache_info=False
|
|
)
|
|
|
|
# Create a simplified mock dataset
|
|
class SimpleMockDataset(torch.utils.data.Dataset):
|
|
def __init__(self):
|
|
self.data = [torch.randn(64, 64) for _ in range(4)]
|
|
|
|
def __len__(self):
|
|
return len(self.data)
|
|
|
|
def __getitem__(self, idx):
|
|
return self.data[idx]
|
|
|
|
# Create a DataLoader to test pin_memory
|
|
dataloader = torch.utils.data.DataLoader(
|
|
SimpleMockDataset(),
|
|
batch_size=2,
|
|
num_workers=0, # Use 0 to avoid multiprocessing overhead
|
|
pin_memory=True
|
|
)
|
|
|
|
# Verify pin_memory works correctly
|
|
for batch in dataloader:
|
|
assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned"
|
|
break
|
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
|
def test_pin_memory_cuda_transfer():
|
|
"""
|
|
Test pin_memory functionality for CUDA tensor transfer
|
|
"""
|
|
# Create a simple dataset
|
|
class SimpleCUDADataset(torch.utils.data.Dataset):
|
|
def __init__(self):
|
|
self.data = [torch.randn(64, 64) for _ in range(4)]
|
|
|
|
def __len__(self):
|
|
return len(self.data)
|
|
|
|
def __getitem__(self, idx):
|
|
return self.data[idx]
|
|
|
|
# Create a DataLoader with pin_memory enabled
|
|
dataloader = torch.utils.data.DataLoader(
|
|
SimpleCUDADataset(),
|
|
batch_size=2,
|
|
num_workers=0, # Use 0 to avoid multiprocessing overhead
|
|
pin_memory=True
|
|
)
|
|
|
|
# Verify CUDA transfer works with pinned memory
|
|
for batch in dataloader:
|
|
cuda_batch = [tensor.to('cuda', non_blocking=True) for tensor in batch]
|
|
assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned"
|
|
break
|
|
|
|
def test_training_scripts_pin_memory_support():
|
|
"""
|
|
Verify that multiple training scripts support pin_memory argument
|
|
"""
|
|
training_scripts = [
|
|
"fine_tune.py",
|
|
"flux_train.py",
|
|
"sd3_train.py",
|
|
"sdxl_train.py",
|
|
"train_network.py",
|
|
"train_textual_inversion.py",
|
|
"sdxl_train_control_net.py",
|
|
"flux_train_control_net.py",
|
|
]
|
|
|
|
from library.train_util import add_training_arguments
|
|
|
|
for script in training_scripts:
|
|
parser = argparse.ArgumentParser()
|
|
add_training_arguments(parser, support_dreambooth=True)
|
|
|
|
# Parse arguments to check pin_memory
|
|
args = parser.parse_args([])
|
|
assert hasattr(args, "pin_memory"), f"{script} should have pin_memory argument"
|
|
|
|
def test_accelerator_pin_memory_config():
|
|
"""
|
|
Test that the Accelerator is configured with pin_memory option
|
|
Checks compatibility and configuration based on Accelerate library version
|
|
"""
|
|
from library.train_util import prepare_accelerator
|
|
|
|
# Check Accelerate library version
|
|
try:
|
|
accelerate_version = importlib.metadata.version("accelerate")
|
|
print(f"Accelerate library version: {accelerate_version}")
|
|
except importlib.metadata.PackageNotFoundError:
|
|
pytest.skip("Accelerate library not installed")
|
|
|
|
# Minimal args to pass initial checks
|
|
args = argparse.Namespace(
|
|
gradient_accumulation_steps=1,
|
|
mixed_precision="no",
|
|
log_with=None,
|
|
kwargs_handlers=[],
|
|
deepspeed_plugin=None,
|
|
pin_memory=True,
|
|
logging_dir=None,
|
|
torch_compile=False,
|
|
log_prefix=None,
|
|
ddp_gradient_as_bucket_view=False,
|
|
ddp_static_graph=False,
|
|
ddp_timeout=None,
|
|
wandb_api_key=None,
|
|
dynamo_backend="NO",
|
|
deepspeed=False,
|
|
)
|
|
|
|
# Prepare accelerator
|
|
accelerator = prepare_accelerator(args)
|
|
|
|
# Check for dataloader_config
|
|
assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled"
|
|
assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory" |