This commit is contained in:
Dave Lage
2025-09-28 00:24:21 +05:30
committed by GitHub
18 changed files with 242 additions and 2 deletions

View File

@@ -244,6 +244,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -126,6 +126,7 @@ def main(args):
batch_size=args.batch_size,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
pin_memory=args.pin_memory,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
@@ -187,6 +188,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")

View File

@@ -113,6 +113,7 @@ def main(args):
dataset,
batch_size=args.batch_size,
shuffle=False,
pin_memory=args.pin_memory,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
@@ -164,6 +165,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
parser.add_argument(
"--remove_words",

View File

@@ -122,6 +122,7 @@ def main(args):
dataset,
batch_size=1,
shuffle=False,
pin_memory=args.pin_memory,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
@@ -223,6 +224,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument(
"--max_resolution",
type=str,

View File

@@ -336,6 +336,7 @@ def main(args):
dataset,
batch_size=args.batch_size,
shuffle=False,
pin_memory=args.pin_memory,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
@@ -410,6 +411,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument(
"--caption_extention",
type=str,

View File

@@ -398,6 +398,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -405,6 +405,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -12,9 +12,8 @@ import pathlib
import re
import shutil
import time
import typing
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
import glob
import math
import os
@@ -209,6 +208,19 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.resize_interpolation: Optional[str] = None
@staticmethod
def _pin_tensor(tensor):
return tensor.pin_memory() if tensor is not None else tensor
def pin_memory(self):
self.latents = self._pin_tensor(self.latents)
self.latents_flipped = self._pin_tensor(self.latents_flipped)
self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
self.alpha_mask = self._pin_tensor(self.alpha_mask)
return self
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -2181,6 +2193,11 @@ class DreamBoothDataset(BaseDataset):
self.num_reg_images = num_reg_images
def pin_memory(self):
    """Pin the cached tensors of every image record in this dataset.

    Entries whose value does not expose a callable ``pin_memory`` are
    skipped silently. Iterates ``image_data.values()`` directly instead
    of indexing the dict by key three separate times per entry, as the
    original did.
    """
    for info in self.image_data.values():
        pin = getattr(info, "pin_memory", None)
        if callable(pin):
            pin()
class FineTuningDataset(BaseDataset):
def __init__(
@@ -4005,6 +4022,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument(
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
@@ -5529,6 +5551,8 @@ def prepare_accelerator(args: argparse.Namespace):
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
@@ -5537,6 +5561,7 @@ def prepare_accelerator(args: argparse.Namespace):
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
dataloader_config=dataloader_config
)
print("accelerator device:", accelerator.device)
return accelerator
@@ -6700,6 +6725,10 @@ class collator_class:
dataset.set_current_step(self.current_step.value)
return examples[0]
def pin_memory(self):
    """Propagate pin_memory to the wrapped dataset, if it supports it.

    Bug fix: the original guard tested ``hasattr(self, 'pin_memory')`` —
    i.e. this very method, which always exists and is always callable —
    so ``self.dataset.pin_memory()`` was invoked unconditionally and
    raised AttributeError for datasets without a ``pin_memory`` method.
    The check now inspects the dataset itself.
    """
    dataset_pin = getattr(self.dataset, "pin_memory", None)
    if callable(dataset_pin):
        dataset_pin()
class LossRecorder:
def __init__(self):

View File

@@ -502,6 +502,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -431,6 +431,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -282,6 +282,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -273,6 +273,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -221,6 +221,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

175
tests/test_pin_memory.py Normal file
View File

@@ -0,0 +1,175 @@
import argparse
import pytest
import torch
import importlib.metadata
import multiprocessing
def test_pin_memory_argument():
    """Verify add_training_arguments registers --pin_memory with a False default."""
    from library.train_util import add_training_arguments

    parser = argparse.ArgumentParser()
    add_training_arguments(parser, support_dreambooth=True)

    # Parse with no CLI flags so every option takes its default value.
    parsed = parser.parse_args([])

    assert hasattr(parsed, "pin_memory"), "pin_memory argument should be present in argument parser"
    assert parsed.pin_memory is False, "pin_memory should default to False"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dreambooth_dataset_pin_memory():
    """
    Test pin_memory functionality using a simple mock dataset.

    A DataLoader created with ``pin_memory=True`` must yield batches whose
    tensors live in page-locked (pinned) host memory. Requires CUDA, since
    pinning needs an available accelerator (hence the skipif above).

    Bug fix: removed the dead ``create_mock_subset()`` helper and the unused
    ``DreamBoothDataset``/``DreamBoothSubset``/``collator_class`` imports —
    they were never called and would break this test on unrelated signature
    changes elsewhere in the project.
    """
    class SimpleMockDataset(torch.utils.data.Dataset):
        """Tiny in-memory dataset: four random 64x64 tensors."""

        def __init__(self):
            self.data = [torch.randn(64, 64) for _ in range(4)]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    dataloader = torch.utils.data.DataLoader(
        SimpleMockDataset(),
        batch_size=2,
        num_workers=0,  # single-process loading keeps the test fast and deterministic
        pin_memory=True,
    )

    # Checking the first batch is sufficient; pinning applies uniformly.
    batch = next(iter(dataloader))
    assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pin_memory_cuda_transfer():
    """
    Test pin_memory functionality for CUDA tensor transfer.

    Bug fix: the original computed ``cuda_batch`` but never asserted
    anything about it, so the "CUDA transfer" was not actually verified.
    Now the test checks that each transferred tensor is on the GPU and
    holds the same values as its pinned host source.
    """
    class SimpleCUDADataset(torch.utils.data.Dataset):
        """Tiny in-memory dataset: four random 64x64 tensors."""

        def __init__(self):
            self.data = [torch.randn(64, 64) for _ in range(4)]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    dataloader = torch.utils.data.DataLoader(
        SimpleCUDADataset(),
        batch_size=2,
        num_workers=0,  # avoid multiprocessing overhead
        pin_memory=True,
    )

    batch = next(iter(dataloader))
    assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned"

    cuda_batch = [tensor.to("cuda", non_blocking=True) for tensor in batch]
    # non_blocking copies must complete before we compare values.
    torch.cuda.synchronize()
    for src, dst in zip(batch, cuda_batch):
        assert dst.is_cuda, "Transferred tensor should be on the CUDA device"
        assert torch.equal(src, dst.cpu()), "Transferred tensor should match its pinned source"
def test_training_scripts_pin_memory_support():
    """
    Verify that multiple training scripts support the pin_memory argument.

    NOTE(review): every listed script shares ``add_training_arguments``, so
    this loop re-checks the same shared parser once per script name rather
    than importing each script's own parser — confirm whether per-script
    parser setup should be exercised instead.
    """
    from library.train_util import add_training_arguments

    training_scripts = (
        "fine_tune.py",
        "flux_train.py",
        "sd3_train.py",
        "sdxl_train.py",
        "train_network.py",
        "train_textual_inversion.py",
        "sdxl_train_control_net.py",
        "flux_train_control_net.py",
    )

    for script in training_scripts:
        parser = argparse.ArgumentParser()
        add_training_arguments(parser, support_dreambooth=True)
        parsed = parser.parse_args([])
        assert hasattr(parsed, "pin_memory"), f"{script} should have pin_memory argument"
def test_accelerator_pin_memory_config():
    """
    Test that the Accelerator is configured with the pin_memory option.

    The accelerator returned by ``prepare_accelerator`` should carry a
    dataloader configuration with ``non_blocking`` enabled when
    ``pin_memory`` is set. Skips when Accelerate is not installed.
    """
    from library.train_util import prepare_accelerator

    # Skip early if the Accelerate library is missing.
    try:
        accelerate_version = importlib.metadata.version("accelerate")
    except importlib.metadata.PackageNotFoundError:
        pytest.skip("Accelerate library not installed")
    print(f"Accelerate library version: {accelerate_version}")

    # Minimal args to pass initial checks inside prepare_accelerator.
    minimal_options = dict(
        gradient_accumulation_steps=1,
        mixed_precision="no",
        log_with=None,
        kwargs_handlers=[],
        deepspeed_plugin=None,
        pin_memory=True,
        logging_dir=None,
        torch_compile=False,
        log_prefix=None,
        ddp_gradient_as_bucket_view=False,
        ddp_static_graph=False,
        ddp_timeout=None,
        wandb_api_key=None,
        dynamo_backend="NO",
        deepspeed=False,
    )

    accelerator = prepare_accelerator(argparse.Namespace(**minimal_options))

    assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled"
    assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory"

View File

@@ -212,6 +212,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -783,6 +783,7 @@ class NetworkTrainer:
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -412,6 +412,7 @@ class TextualInversionTrainer:
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -317,6 +317,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)