Compare commits

...

12 Commits

Author | SHA1 | Message | Date
Dave Lage | cfff446aab | Merge d73f2db22c into a5a162044c | 2025-11-03 14:58:48 +01:00
Kohya S. | a5a162044c | Merge pull request #2226 from kohya-ss/fix-hunyuan-image-batch-gen-error (fix: error on batch generation closes #2209) | 2025-10-15 21:57:45 +09:00
Kohya S | a33cad714e | fix: error on batch generation closes #2209 | 2025-10-15 21:57:11 +09:00
Kohya S. | 5e366acda4 | Merge pull request #2003 from laolongboy/sd3-dev (Fix missing parameters in model conversion script) | 2025-10-01 21:03:12 +09:00
rockerBOO | d73f2db22c | Pinning only supported when CUDA is available | 2025-06-16 18:31:46 -04:00
rockerBOO | 81df559406 | Update pin_memory tests to use DataLoader | 2025-06-16 18:08:19 -04:00
rockerBOO | 95e260fb99 | Add tests for pin memory | 2025-06-16 17:53:02 -04:00
rockerBOO | 098122340a | Merge branch 'sd3' into pin_memory | 2025-06-16 17:27:49 -04:00
laolongboy | e64dc05c2a | Supplement the input parameters to correctly convert the flux model to BFL format; fixes #1996 | 2025-03-24 23:33:25 +08:00
rockerBOO | 03b35be387 | Add pin_memory to finetune scripts | 2025-01-23 12:45:37 -05:00
rockerBOO | 50d8daa7d8 | Accelerate dataloader_config to non_blocking if pin_memory is enabled | 2025-01-23 11:02:29 -05:00
rockerBOO | c4b0bb6fce | Add pin_memory to DataLoader and update ImageInfo to support | 2025-01-23 10:39:01 -05:00
20 changed files with 245 additions and 5 deletions

View File

@@ -244,6 +244,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
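
A note on what the added flag buys (not part of the diff): DataLoader(pin_memory=True) collates batches into page-locked host memory, which is what lets the accelerate change further down issue asynchronous host-to-device copies. A minimal sketch, assuming a CUDA device is present:

import torch

# Sketch: pinned batches can be copied to the GPU with non_blocking=True,
# overlapping the transfer with compute. Requires CUDA to actually pin.
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(16, 3, 64, 64)),
    batch_size=4,
    pin_memory=True,
)
for (batch,) in loader:
    assert batch.is_pinned()
    batch = batch.to("cuda", non_blocking=True)  # asynchronous H2D copy
    break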

View File

@@ -126,6 +126,7 @@ def main(args):
         batch_size=args.batch_size,
         shuffle=False,
         num_workers=args.max_data_loader_n_workers,
+        pin_memory=args.pin_memory,
         collate_fn=collate_fn_remove_corrupted,
         drop_last=False,
     )
@@ -187,6 +188,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる")
     parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
     parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
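
A quick sketch of the flag's semantics (standard argparse, nothing repo-specific): store_true options default to False, so existing invocations of these scripts are unaffected until --pin_memory is passed explicitly.

import argparse

# Sketch: the new option is strictly opt-in.
parser = argparse.ArgumentParser()
parser.add_argument("--pin_memory", action="store_true", help="Pin memory for faster GPU loading")

assert parser.parse_args([]).pin_memory is False
assert parser.parse_args(["--pin_memory"]).pin_memory is True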

View File

@@ -113,6 +113,7 @@ def main(args):
         dataset,
         batch_size=args.batch_size,
         shuffle=False,
+        pin_memory=args.pin_memory,
         num_workers=args.max_data_loader_n_workers,
         collate_fn=collate_fn_remove_corrupted,
         drop_last=False,
@@ -164,6 +165,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
     parser.add_argument(
         "--remove_words",

View File

@@ -122,6 +122,7 @@ def main(args):
         dataset,
         batch_size=1,
         shuffle=False,
+        pin_memory=args.pin_memory,
         num_workers=args.max_data_loader_n_workers,
         collate_fn=collate_fn_remove_corrupted,
         drop_last=False,
@@ -223,6 +224,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument(
         "--max_resolution",
         type=str,

View File

@@ -336,6 +336,7 @@ def main(args):
         dataset,
         batch_size=args.batch_size,
         shuffle=False,
+        pin_memory=args.pin_memory,
         num_workers=args.max_data_loader_n_workers,
         collate_fn=collate_fn_remove_corrupted,
         drop_last=False,
@@ -410,6 +411,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument(
         "--caption_extention",
         type=str,

View File

@@ -398,6 +398,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )

View File

@@ -405,6 +405,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )

View File

@@ -1001,7 +1001,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
             all_precomputed_text_data.append(text_data)

         # Models should be removed from device after prepare_text_inputs
-        del tokenizer_batch, text_encoder_batch, temp_shared_models_txt, conds_cache_batch
+        del tokenizer_vlm, text_encoder_vlm_batch, tokenizer_byt5, text_encoder_byt5_batch, temp_shared_models_txt, conds_cache_batch
         gc.collect()  # Force cleanup of Text Encoder from GPU memory
         clean_memory_on_device(device)
@@ -1075,7 +1075,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
            # save_output expects latent to be [BCTHW] or [CTHW]. generate returns [BCTHW] (batch size 1).
            # latent[0] is correct if generate returns it with batch dim.
            # The latent from generate is (1, C, T, H, W)
-           save_output(current_args, vae_for_batch, latent[0], device)  # Pass vae_for_batch
+           save_output(current_args, vae_for_batch, latent, device)  # Pass vae_for_batch
            vae_for_batch.to("cpu")  # Move VAE back to CPU
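
The second hunk is easier to read with the shape contract spelled out. A hedged sketch (the shape convention follows the diff's own comments; the sizes are placeholders): generate() returns latents shaped (B, C, T, H, W), so indexing [0] hands save_output a single (C, T, H, W) sample, while passing latent keeps the batch dimension that batch generation needs.

import torch

# Placeholder sizes, not the script's real ones.
latent = torch.randn(4, 16, 9, 32, 32)  # (B, C, T, H, W), batch of 4

assert latent[0].shape == latent.shape[1:]  # one sample, batch dim dropped
assert latent.shape[0] == 4                 # full batch preserved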

View File

@@ -12,9 +12,8 @@ import pathlib
 import re
 import shutil
 import time
-import typing
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
 import glob
 import math
 import os
@@ -209,6 +208,19 @@ class ImageInfo:
         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
         self.resize_interpolation: Optional[str] = None

+    @staticmethod
+    def _pin_tensor(tensor):
+        return tensor.pin_memory() if tensor is not None else tensor
+
+    def pin_memory(self):
+        self.latents = self._pin_tensor(self.latents)
+        self.latents_flipped = self._pin_tensor(self.latents_flipped)
+        self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
+        self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
+        self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
+        self.alpha_mask = self._pin_tensor(self.alpha_mask)
+        return self
+

 class BucketManager:
     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -2181,6 +2193,11 @@ class DreamBoothDataset(BaseDataset):
         self.num_reg_images = num_reg_images

+    def pin_memory(self):
+        for key in self.image_data.keys():
+            if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
+                self.image_data[key].pin_memory()
+

 class FineTuningDataset(BaseDataset):
     def __init__(
@@ -4005,6 +4022,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         action="store_true",
         help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
     parser.add_argument(
         "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
@@ -5529,6 +5551,8 @@ def prepare_accelerator(args: argparse.Namespace):
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

+    dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory)
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -5537,6 +5561,7 @@ def prepare_accelerator(args: argparse.Namespace):
         kwargs_handlers=kwargs_handlers,
         dynamo_backend=dynamo_backend,
         deepspeed_plugin=deepspeed_plugin,
+        dataloader_config=dataloader_config
     )
     print("accelerator device:", accelerator.device)
     return accelerator
@@ -6700,6 +6725,10 @@ class collator_class:
             dataset.set_current_step(self.current_step.value)
         return examples[0]

+    def pin_memory(self):
+        if hasattr(self.dataset, 'pin_memory') and callable(self.dataset.pin_memory):
+            self.dataset.pin_memory()
+

 class LossRecorder:
     def __init__(self):
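
Taken together, these hunks form a small chain: the DataLoader pins collated batches, ImageInfo.pin_memory pins cached latents and text-encoder outputs, the collator forwards pin_memory() to the dataset, and DataLoaderConfiguration(non_blocking=...) makes accelerate's device placement asynchronous. A minimal sketch of the accelerate side (a stub, not the repo's full prepare_accelerator):

from accelerate import Accelerator, DataLoaderConfiguration

# Sketch: non_blocking=True moves batches to the device with non-blocking
# copies; this only pays off when the source tensors are pinned.
dataloader_config = DataLoaderConfiguration(non_blocking=True)
accelerator = Accelerator(dataloader_config=dataloader_config)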

View File

@@ -502,6 +502,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )

View File

@@ -431,6 +431,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )

View File

@@ -282,6 +282,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )

View File

@@ -273,6 +273,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )

View File

@@ -221,6 +221,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )

tests/test_pin_memory.py (new file, +175 lines)
View File

@@ -0,0 +1,175 @@
import argparse
import pytest
import torch
import importlib.metadata
import multiprocessing


def test_pin_memory_argument():
    """
    Test that the pin_memory argument is correctly added to argument parsers
    """
    from library.train_util import add_training_arguments

    parser = argparse.ArgumentParser()
    add_training_arguments(parser, support_dreambooth=True)

    # Parse an empty list of arguments to check the default
    args = parser.parse_args([])

    assert hasattr(args, "pin_memory"), "pin_memory argument should be present in argument parser"
    assert args.pin_memory is False, "pin_memory should default to False"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dreambooth_dataset_pin_memory():
    """
    Test pin_memory functionality using a simple mock dataset
    """
    from library.train_util import DreamBoothDataset, DreamBoothSubset, collator_class

    # Create a mock DreamBoothSubset with minimal arguments
    def create_mock_subset():
        return DreamBoothSubset(
            image_dir='/mock/path',
            is_reg=False,
            class_tokens='test_token',
            caption_extension='.txt',
            alpha_mask=False,
            num_repeats=1,
            shuffle_caption=False,
            caption_separator=',',
            keep_tokens=0,
            keep_tokens_separator='',
            secondary_separator='',
            enable_wildcard=False,
            color_aug=False,
            flip_aug=False,
            face_crop_aug_range=None,
            random_crop=False,
            caption_dropout_rate=0,
            caption_dropout_every_n_epochs=0,
            caption_tag_dropout_rate=0,
            caption_prefix='',
            caption_suffix='',
            token_warmup_min=1,
            token_warmup_step=0,
            cache_info=False
        )

    # Create a simplified mock dataset
    class SimpleMockDataset(torch.utils.data.Dataset):
        def __init__(self):
            self.data = [torch.randn(64, 64) for _ in range(4)]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    # Create a DataLoader to test pin_memory
    dataloader = torch.utils.data.DataLoader(
        SimpleMockDataset(),
        batch_size=2,
        num_workers=0,  # Use 0 to avoid multiprocessing overhead
        pin_memory=True
    )

    # Verify pin_memory works correctly
    for batch in dataloader:
        # Pinning only works when CUDA is available
        assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned"
        break


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pin_memory_cuda_transfer():
    """
    Test pin_memory functionality for CUDA tensor transfer
    """
    # Create a simple dataset
    class SimpleCUDADataset(torch.utils.data.Dataset):
        def __init__(self):
            self.data = [torch.randn(64, 64) for _ in range(4)]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx]

    # Create a DataLoader with pin_memory enabled
    dataloader = torch.utils.data.DataLoader(
        SimpleCUDADataset(),
        batch_size=2,
        num_workers=0,  # Use 0 to avoid multiprocessing overhead
        pin_memory=True
    )

    # Verify CUDA transfer works with pinned memory
    for batch in dataloader:
        cuda_batch = [tensor.to('cuda', non_blocking=True) for tensor in batch]
        assert all(tensor.is_pinned() for tensor in batch), "All tensors should be pinned"
        break


def test_training_scripts_pin_memory_support():
    """
    Verify that multiple training scripts support pin_memory argument
    """
    training_scripts = [
        "fine_tune.py",
        "flux_train.py",
        "sd3_train.py",
        "sdxl_train.py",
        "train_network.py",
        "train_textual_inversion.py",
        "sdxl_train_control_net.py",
        "flux_train_control_net.py",
    ]

    from library.train_util import add_training_arguments

    for script in training_scripts:
        parser = argparse.ArgumentParser()
        add_training_arguments(parser, support_dreambooth=True)

        # Parse arguments to check pin_memory
        args = parser.parse_args([])
        assert hasattr(args, "pin_memory"), f"{script} should have pin_memory argument"


def test_accelerator_pin_memory_config():
    """
    Test that the Accelerator is configured with pin_memory option
    Checks compatibility and configuration based on Accelerate library version
    """
    from library.train_util import prepare_accelerator

    # Check Accelerate library version
    try:
        accelerate_version = importlib.metadata.version("accelerate")
        print(f"Accelerate library version: {accelerate_version}")
    except importlib.metadata.PackageNotFoundError:
        pytest.skip("Accelerate library not installed")

    # Minimal args to pass initial checks
    args = argparse.Namespace(
        gradient_accumulation_steps=1,
        mixed_precision="no",
        log_with=None,
        kwargs_handlers=[],
        deepspeed_plugin=None,
        pin_memory=True,
        logging_dir=None,
        torch_compile=False,
        log_prefix=None,
        ddp_gradient_as_bucket_view=False,
        ddp_static_graph=False,
        ddp_timeout=None,
        wandb_api_key=None,
        dynamo_backend="NO",
        deepspeed=False,
    )

    # Prepare accelerator
    accelerator = prepare_accelerator(args)

    # Check for dataloader_config
    assert hasattr(accelerator, "dataloader_config"), "Accelerator should have dataloader_config when pin_memory is enabled"
    assert accelerator.dataloader_config.non_blocking is True, "Dataloader should be configured with pin_memory"
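
The skipif markers reflect commit d73f2db22c ("Pinning only supported when CUDA is available"): Tensor.pin_memory() needs a CUDA context, so CPU-only runs exercise only the argparse and script-coverage tests. A one-liner illustrating the guard:

import torch

# Sketch of the guard behind the skipif markers: pinning requires CUDA.
t = torch.randn(2, 2)
if torch.cuda.is_available():
    assert t.pin_memory().is_pinned()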

View File

@@ -57,7 +57,7 @@ def convert(args):
     save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None

     # make reverse map from diffusers map
-    diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
+    diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map(19, 38)

     # iterate over three safetensors files to reduce memory usage
     flux_sd = {}
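
For context on the two magic numbers (a fact about FLUX.1, not stated in the diff): the model has 19 double-stream and 38 single-stream blocks, and per commit e64dc05c2a the conversion script was not passing those counts, which broke the reverse key map (#1996). A sketch, assuming sd-scripts' library package is importable:

from library import flux_utils

# Build the diffusers -> BFL key map for FLUX.1's 19 double / 38 single
# blocks (call shape taken directly from the diff above).
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map(19, 38)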

View File

@@ -212,6 +212,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )

View File

@@ -783,6 +783,7 @@ class NetworkTrainer:
            shuffle=True,
            collate_fn=collator,
            num_workers=n_workers,
+           pin_memory=args.pin_memory,
            persistent_workers=args.persistent_data_loader_workers,
        )

View File

@@ -412,6 +412,7 @@ class TextualInversionTrainer:
            shuffle=True,
            collate_fn=collator,
            num_workers=n_workers,
+           pin_memory=args.pin_memory,
            persistent_workers=args.persistent_data_loader_workers,
        )

View File

@@ -317,6 +317,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )