diff --git a/library/strategy_base.py b/library/strategy_base.py
index c7f6e39b..5d7e9593 100644
--- a/library/strategy_base.py
+++ b/library/strategy_base.py
@@ -55,6 +55,7 @@ class TokenizeStrategy:
         self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
     ) -> Any:
         tokenizer = None
+        # TODO: adapt skip_npz_check here as well (low priority since TE latents are not cached in this setup)
         if tokenizer_cache_dir:
             local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
             if os.path.exists(local_tokenizer_path):
@@ -435,8 +436,10 @@ class LatentsCachingStrategy:
     ):
         if not self.cache_to_disk:
             return False
-        if not os.path.exists(npz_path):
-            return False
+        # In multi-node training, os.path.exists() can hang (e.g. on shared filesystems); whether np.load() is affected is unverified.
+        if not self.skip_npz_check:
+            if not os.path.exists(npz_path):
+                return False
         if self.skip_disk_cache_validity_check:
             return True
diff --git a/library/strategy_sd.py b/library/strategy_sd.py
index a44fc409..a227d871 100644
--- a/library/strategy_sd.py
+++ b/library/strategy_sd.py
@@ -152,9 +152,13 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
     def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
         # support old .npz
         old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
-        if os.path.exists(old_npz_file):
+        if self.skip_npz_check:
+            # TODO: confirm expected user behavior; skipping the disk check means we cannot tell which cache file actually exists
             return old_npz_file
-        return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
+        elif os.path.exists(old_npz_file):
+            return old_npz_file
+        else:
+            return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix

     def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
         return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
diff --git a/library/train_util.py b/library/train_util.py
index ebe9f706..5eecf2d9 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -14,7 +14,7 @@ import shutil
 import time
 import typing
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
 import glob
 import math
 import os
@@ -218,6 +218,18 @@ class ImageInfo:

         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime

+    @staticmethod
+    def _pin_tensor(tensor):
+        return tensor.pin_memory() if tensor is not None else tensor
+
+    def pin_memory(self):
+        self.latents = self._pin_tensor(self.latents)
+        self.latents_flipped = self._pin_tensor(self.latents_flipped)
+        self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
+        self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
+        self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
+        self.alpha_mask = self._pin_tensor(self.alpha_mask)
+        return self

 class BucketManager:
     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -441,6 +453,7 @@ class BaseSubset:
         custom_attributes: Optional[Dict[str, Any]] = None,
         validation_seed: Optional[int] = None,
         validation_split: Optional[float] = 0.0,
+        skip_npz_check: Optional[bool] = SKIP_NPZ_PATH_CHECK,
     ) -> None:
         self.image_dir = image_dir
         self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -471,6 +484,7 @@ class BaseSubset:

         self.validation_seed = validation_seed
         self.validation_split = validation_split
+        self.skip_npz_check = skip_npz_check  # for multi-node training

 class DreamBoothSubset(BaseSubset):
     def __init__(
@@ -1841,6 +1855,11 @@ class BaseDataset(torch.utils.data.Dataset):
         example["bucket_reso"] = bucket_reso
         return example

+    def pin_memory(self):
+        # pin the cached tensors of every ImageInfo so pinned-memory DataLoader transfers can benefit
+        for key in self.image_data.keys():
+            if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
+                self.image_data[key].pin_memory()

 class DreamBoothDataset(BaseDataset):
     IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
@@ -2129,6 +2148,10 @@ class DreamBoothDataset(BaseDataset):

         self.num_reg_images = num_reg_images

+    def pin_memory(self):
+        for key in self.image_data.keys():
+            if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
+                self.image_data[key].pin_memory()

 class FineTuningDataset(BaseDataset):
     def __init__(
@@ -2358,6 +2381,10 @@ class FineTuningDataset(BaseDataset):

         return npz_file_norm, npz_file_flip

+    def pin_memory(self):
+        for key in self.image_data.keys():
+            if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
+                self.image_data[key].pin_memory()

 class ControlNetDataset(BaseDataset):
     def __init__(
@@ -3840,6 +3867,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         action="store_true",
         help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="pin DataLoader memory for faster host-to-GPU transfer (may have side effects on Windows) / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
     parser.add_argument(
         "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
     )
@@ -5387,6 +5419,8 @@ def prepare_accelerator(args: argparse.Namespace):
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

+    dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory)
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -5395,6 +5429,7 @@ def prepare_accelerator(args: argparse.Namespace):
         kwargs_handlers=kwargs_handlers,
         dynamo_backend=dynamo_backend,
         deepspeed_plugin=deepspeed_plugin,
+        dataloader_config=dataloader_config,
     )
     print("accelerator device:", accelerator.device)
     return accelerator
@@ -6590,6 +6625,10 @@ class collator_class:
             dataset.set_current_step(self.current_step.value)
         return examples[0]

+    def pin_memory(self):
+        # check the wrapped dataset (not self) before delegating; self trivially has this attribute
+        if hasattr(self.dataset, 'pin_memory') and callable(self.dataset.pin_memory):
+            self.dataset.pin_memory()

 class LossRecorder:
     def __init__(self):
diff --git a/train_native.py b/train_native.py
index 4d315daf..45c94a23 100644
--- a/train_native.py
+++ b/train_native.py
@@ -796,12 +796,15 @@ class NativeTrainer:
         # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers

+        pin_memory = args.pin_memory
+
         train_dataloader = torch.utils.data.DataLoader(
             train_dataset_group,
             batch_size=1,
             shuffle=True,
             collate_fn=collator,
             num_workers=n_workers,
+            pin_memory=pin_memory,
             persistent_workers=args.persistent_data_loader_workers,
         )

@@ -811,6 +814,7 @@ class NativeTrainer:
             batch_size=1,
             collate_fn=collator,
             num_workers=n_workers,
+            pin_memory=pin_memory,
             persistent_workers=args.persistent_data_loader_workers,
         )
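
Note on the design: torch.utils.data.DataLoader(pin_memory=True) collates batches into page-locked
host memory, and DataLoaderConfiguration(non_blocking=args.pin_memory) then lets accelerate move them
to the device with non_blocking=True; such a copy only overlaps with compute when the source is pinned,
which is why both settings hang off the same flag. A minimal standalone sketch of that interaction
(illustration only, not part of the patch; shapes are arbitrary):

    import torch

    # pinned (page-locked) host tensor: prerequisite for a truly asynchronous copy
    batch = torch.randn(4, 3, 64, 64).pin_memory()
    if torch.cuda.is_available():
        # overlaps with GPU work because the source is pinned; from pageable
        # memory this would silently fall back to a synchronous copy
        batch_gpu = batch.to("cuda", non_blocking=True)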
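Usage sketch (hypothetical invocation; everything besides the new --pin_memory flag is a placeholder
for whatever arguments the run already uses):

    accelerate launch train_native.py \
        --max_data_loader_n_workers 4 \
        --persistent_data_loader_workers \
        --pin_memory \
        ...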