pin_memory and skip_npz_check

This commit is contained in:
Darren Lau
2025-03-06 16:36:26 +08:00
parent 57cc9ea0ad
commit 040c04f17e
4 changed files with 55 additions and 5 deletions

View File

@@ -55,6 +55,7 @@ class TokenizeStrategy:
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
) -> Any:
tokenizer = None
#TODO: adapt skip_npz_check here as well (not currently needed, since text-encoder latents are not cached)
if tokenizer_cache_dir:
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
@@ -435,8 +436,10 @@ class LatentsCachingStrategy:
):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
# In multinode training, os.path.exists may hang (e.g. on shared network filesystems); whether np.load is affected as well is unverified.
if not self.skip_npz_check:
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True

View File

@@ -152,9 +152,13 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
    """Return the latents .npz cache path for *absolute_path* at *image_size*.

    Prefers a legacy (old-suffix) .npz beside the image when one exists;
    otherwise uses the size-tagged new-style name. When ``skip_npz_check`` is
    set, the filesystem existence probe is skipped entirely (``os.path``
    calls can hang on shared filesystems in multinode training) and the
    new-style name is always used.
    """
    # Hoisted: the new-style path was previously duplicated in two branches.
    new_npz_file = os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
    if self.skip_npz_check:
        # TODO(review): confirm no users rely on legacy .npz files being
        # picked up while skip_npz_check is enabled (lack of information).
        return new_npz_file
    # support old .npz
    old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
    if os.path.exists(old_npz_file):
        return old_npz_file
    return new_npz_file
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
    # Delegate to the shared base-class check. The constant 8 is presumably
    # the SD/SDXL VAE latent downscale factor used to validate the cached
    # latent resolution against bucket_reso — TODO confirm against the helper.
    return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)

View File

@@ -14,7 +14,7 @@ import shutil
import time
import typing
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
import glob
import math
import os
@@ -218,6 +218,18 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
@staticmethod
def _pin_tensor(tensor):
return tensor.pin_memory() if tensor is not None else tensor
def pin_memory(self):
    """Pin every cached tensor held by this ImageInfo into page-locked
    memory (in place), enabling faster/async host-to-device copies.
    Returns self for chaining."""
    for attr_name in (
        "latents",
        "latents_flipped",
        "text_encoder_outputs1",
        "text_encoder_outputs2",
        "text_encoder_pool2",
        "alpha_mask",
    ):
        setattr(self, attr_name, self._pin_tensor(getattr(self, attr_name)))
    return self
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -441,6 +453,7 @@ class BaseSubset:
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
skip_npz_check: Optional[bool] = SKIP_NPZ_PATH_CHECK,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -471,6 +484,7 @@ class BaseSubset:
self.validation_seed = validation_seed
self.validation_split = validation_split
self.skip_npz_check = skip_npz_check # For multinode training
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -1841,6 +1855,11 @@ class BaseDataset(torch.utils.data.Dataset):
example["bucket_reso"] = bucket_reso
return example
def pin_memory(self):
    """Pin the cached tensors of every image entry in this dataset.

    Delegates to each entry's own ``pin_memory`` when it has one; entries
    without the method are skipped. Iterates ``values()`` directly instead
    of looping over keys with three dict lookups per entry.
    """
    for info in self.image_data.values():
        pin = getattr(info, "pin_memory", None)
        if callable(pin):
            pin()
class DreamBoothDataset(BaseDataset):
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
@@ -2129,6 +2148,10 @@ class DreamBoothDataset(BaseDataset):
self.num_reg_images = num_reg_images
def pin_memory(self):
    """Pin the cached tensors of every image entry in this dataset.

    Delegates to each entry's own ``pin_memory`` when it has one; entries
    without the method are skipped. Iterates ``values()`` directly instead
    of looping over keys with three dict lookups per entry.
    """
    for info in self.image_data.values():
        pin = getattr(info, "pin_memory", None)
        if callable(pin):
            pin()
class FineTuningDataset(BaseDataset):
def __init__(
@@ -2358,6 +2381,10 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
def pin_memory(self):
    """Pin the cached tensors of every image entry in this dataset.

    Delegates to each entry's own ``pin_memory`` when it has one; entries
    without the method are skipped. Iterates ``values()`` directly instead
    of looping over keys with three dict lookups per entry.
    """
    for info in self.image_data.values():
        pin = getattr(info, "pin_memory", None)
        if callable(pin):
            pin()
class ControlNetDataset(BaseDataset):
def __init__(
@@ -3840,6 +3867,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading (Windows will have side effect) / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument(
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
@@ -5387,6 +5419,8 @@ def prepare_accelerator(args: argparse.Namespace):
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
@@ -5395,6 +5429,7 @@ def prepare_accelerator(args: argparse.Namespace):
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
dataloader_config=dataloader_config
)
print("accelerator device:", accelerator.device)
return accelerator
@@ -6590,6 +6625,10 @@ class collator_class:
dataset.set_current_step(self.current_step.value)
return examples[0]
def pin_memory(self):
    """Delegate pinning to the wrapped dataset, if it supports it.

    Bug fix: the original guard checked ``hasattr(self, 'pin_memory')`` /
    ``callable(self.pin_memory)`` — i.e. it tested this very method, which
    is always true — so a dataset without ``pin_memory`` would raise
    AttributeError. The guard must inspect ``self.dataset`` instead.
    """
    ds_pin = getattr(self.dataset, "pin_memory", None)
    if callable(ds_pin):
        ds_pin()
class LossRecorder:
def __init__(self):

View File

@@ -796,12 +796,15 @@ class NativeTrainer:
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
pin_memory = args.pin_memory
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -811,6 +814,7 @@ class NativeTrainer:
batch_size=1,
collate_fn=collator,
num_workers=n_workers,
pin_memory=pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)