mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge d73f2db22c into a5a162044c
This commit is contained in:
@@ -244,6 +244,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -126,6 +126,7 @@ def main(args):
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
@@ -187,6 +188,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pin_memory",
|
||||
action="store_true",
|
||||
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
|
||||
)
|
||||
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
||||
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
||||
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
||||
|
||||
@@ -113,6 +113,7 @@ def main(args):
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
@@ -164,6 +165,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pin_memory",
|
||||
action="store_true",
|
||||
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
|
||||
)
|
||||
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
|
||||
parser.add_argument(
|
||||
"--remove_words",
|
||||
|
||||
@@ -122,6 +122,7 @@ def main(args):
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
@@ -223,6 +224,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pin_memory",
|
||||
action="store_true",
|
||||
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_resolution",
|
||||
type=str,
|
||||
|
||||
@@ -336,6 +336,7 @@ def main(args):
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
@@ -410,6 +411,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pin_memory",
|
||||
action="store_true",
|
||||
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_extention",
|
||||
type=str,
|
||||
|
||||
@@ -398,6 +398,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -405,6 +405,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,9 +12,8 @@ import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import typing
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
@@ -209,6 +208,19 @@ class ImageInfo:
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
self.resize_interpolation: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def _pin_tensor(tensor):
|
||||
return tensor.pin_memory() if tensor is not None else tensor
|
||||
|
||||
def pin_memory(self):
    """Pin every cached tensor of this image entry into page-locked memory.

    Fields that were never cached (``None``) are left untouched by
    ``_pin_tensor``. Returns ``self`` so calls can be chained.
    """
    for attr in (
        "latents",
        "latents_flipped",
        "text_encoder_outputs1",
        "text_encoder_outputs2",
        "text_encoder_pool2",
        "alpha_mask",
    ):
        setattr(self, attr, self._pin_tensor(getattr(self, attr)))
    return self
|
||||
|
||||
|
||||
class BucketManager:
|
||||
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
||||
@@ -2181,6 +2193,11 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
self.num_reg_images = num_reg_images
|
||||
|
||||
def pin_memory(self):
    """Ask each cached image entry to pin its tensors into page-locked memory.

    Entries that do not expose a callable ``pin_memory`` are skipped.
    """
    for info in self.image_data.values():
        pin = getattr(info, "pin_memory", None)
        if callable(pin):
            pin()
|
||||
|
||||
|
||||
class FineTuningDataset(BaseDataset):
|
||||
def __init__(
|
||||
@@ -4005,6 +4022,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
action="store_true",
|
||||
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pin_memory",
|
||||
action="store_true",
|
||||
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
|
||||
@@ -5529,6 +5551,8 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
|
||||
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
|
||||
|
||||
dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
@@ -5537,6 +5561,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
dynamo_backend=dynamo_backend,
|
||||
deepspeed_plugin=deepspeed_plugin,
|
||||
dataloader_config=dataloader_config
|
||||
)
|
||||
print("accelerator device:", accelerator.device)
|
||||
return accelerator
|
||||
@@ -6700,6 +6725,10 @@ class collator_class:
|
||||
dataset.set_current_step(self.current_step.value)
|
||||
return examples[0]
|
||||
|
||||
def pin_memory(self):
    """Forward the DataLoader's pin-memory request to the wrapped dataset.

    Bug fix: the original guard was ``hasattr(self, 'pin_memory') and
    callable(self.pin_memory)`` — that inspects this very method on the
    collator and is therefore always true, then calls
    ``self.dataset.pin_memory()`` unconditionally (raising AttributeError
    for datasets without pin support). The guard must inspect the dataset.
    """
    pin = getattr(self.dataset, "pin_memory", None)
    if callable(pin):
        pin()
|
||||
|
||||
|
||||
class LossRecorder:
|
||||
def __init__(self):
|
||||
|
||||
@@ -502,6 +502,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -431,6 +431,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -282,6 +282,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -273,6 +273,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -221,6 +221,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
175
tests/test_pin_memory.py
Normal file
175
tests/test_pin_memory.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import argparse
|
||||
import pytest
|
||||
import torch
|
||||
import importlib.metadata
|
||||
import multiprocessing
|
||||
|
||||
def test_pin_memory_argument():
    """The shared training argument parser exposes --pin_memory, defaulting to False."""
    from library.train_util import add_training_arguments

    parser = argparse.ArgumentParser()
    add_training_arguments(parser, support_dreambooth=True)

    # Parsing no CLI arguments yields the store_true default.
    defaults = parser.parse_args([])
    assert hasattr(defaults, "pin_memory"), "pin_memory argument should be present in argument parser"
    assert defaults.pin_memory is False, "pin_memory should default to False"
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dreambooth_dataset_pin_memory():
    """
    Test pin_memory functionality using a simple mock dataset.

    Cleanup: the original version also imported DreamBoothDataset /
    DreamBoothSubset / collator_class and defined an 80-line
    ``create_mock_subset`` factory that was never called. That dead code is
    removed — the test only ever exercised a plain torch Dataset through a
    pinning DataLoader.
    """

    # Minimal in-memory dataset of CPU tensors.
    class SimpleMockDataset(torch.utils.data.Dataset):
        def __init__(self):
            self.data = [torch.randn(64, 64) for _ in range(4)]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    # Create a DataLoader to test pin_memory.
    dataloader = torch.utils.data.DataLoader(
        SimpleMockDataset(),
        batch_size=2,
        num_workers=0,  # Use 0 to avoid multiprocessing overhead
        pin_memory=True,
    )

    # Verify pin_memory works correctly.
    for batch in dataloader:
        # Pinning only works when CUDA is available (hence the skipif above).
        assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned"
        break
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pin_memory_cuda_transfer():
    """
    Test pin_memory functionality for CUDA tensor transfer
    """

    # Four random 64x64 CPU tensors served as a torch Dataset.
    class SimpleCUDADataset(torch.utils.data.Dataset):
        def __init__(self):
            self.data = [torch.randn(64, 64) for _ in range(4)]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    loader = torch.utils.data.DataLoader(
        SimpleCUDADataset(),
        batch_size=2,
        num_workers=0,  # single-process loading keeps the test fast
        pin_memory=True,
    )

    # Take one batch; an async host->device copy is only meaningful from pinned memory.
    first_batch = next(iter(loader))
    _ = [tensor.to('cuda', non_blocking=True) for tensor in first_batch]
    assert all(tensor.is_pinned() for tensor in first_batch), "All tensors should be pinned"
|
||||
|
||||
def test_training_scripts_pin_memory_support():
    """
    Verify that multiple training scripts support pin_memory argument
    """
    # NOTE(review): the loop below never imports or reads the listed script
    # files — each iteration rebuilds the same parser from the shared
    # add_training_arguments helper, so this effectively re-checks that
    # helper N times. It verifies the scripts only to the extent that they
    # all call add_training_arguments; confirm that is the intent.
    training_scripts = [
        "fine_tune.py",
        "flux_train.py",
        "sd3_train.py",
        "sdxl_train.py",
        "train_network.py",
        "train_textual_inversion.py",
        "sdxl_train_control_net.py",
        "flux_train_control_net.py",
    ]

    from library.train_util import add_training_arguments

    for script in training_scripts:
        # Fresh parser per script name; populated by the shared helper.
        parser = argparse.ArgumentParser()
        add_training_arguments(parser, support_dreambooth=True)

        # Parse arguments to check pin_memory
        args = parser.parse_args([])
        assert hasattr(args, "pin_memory"), f"{script} should have pin_memory argument"
|
||||
|
||||
def test_accelerator_pin_memory_config():
    """
    Test that the Accelerator is configured with pin_memory option
    Checks compatibility and configuration based on Accelerate library version
    """
    from library.train_util import prepare_accelerator

    # Check Accelerate library version
    try:
        accelerate_version = importlib.metadata.version("accelerate")
        print(f"Accelerate library version: {accelerate_version}")
    except importlib.metadata.PackageNotFoundError:
        pytest.skip("Accelerate library not installed")

    # Minimal args to pass initial checks
    # NOTE(review): this Namespace mirrors only the attributes
    # prepare_accelerator is known to read; if that function gains new
    # required args the test will fail with AttributeError — keep in sync.
    args = argparse.Namespace(
        gradient_accumulation_steps=1,
        mixed_precision="no",
        log_with=None,
        kwargs_handlers=[],
        deepspeed_plugin=None,
        pin_memory=True,
        logging_dir=None,
        torch_compile=False,
        log_prefix=None,
        ddp_gradient_as_bucket_view=False,
        ddp_static_graph=False,
        ddp_timeout=None,
        wandb_api_key=None,
        dynamo_backend="NO",
        deepspeed=False,
    )

    # Prepare accelerator
    accelerator = prepare_accelerator(args)

    # Check for dataloader_config
    # NOTE(review): assumes Accelerator exposes a public `dataloader_config`
    # attribute and that pin_memory maps to `non_blocking=True` (async
    # transfers), not DataLoader pinning itself — confirm against the
    # installed accelerate version's API.
    assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled"
    assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory"
|
||||
@@ -212,6 +212,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -783,6 +783,7 @@ class NetworkTrainer:
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -412,6 +412,7 @@ class TextualInversionTrainer:
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -317,6 +317,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user