mirror of https://github.com/kohya-ss/sd-scripts.git

pin_memory and skip_npz_check
@@ -55,6 +55,7 @@ class TokenizeStrategy:
         self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
     ) -> Any:
         tokenizer = None
+        # TODO: skip_npz_check adaptation (however, TE latents are not cached here)
         if tokenizer_cache_dir:
            local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
            if os.path.exists(local_tokenizer_path):
@@ -435,8 +436,10 @@ class LatentsCachingStrategy:
     ):
         if not self.cache_to_disk:
             return False
-        if not os.path.exists(npz_path):
-            return False
+        # In multinode training, os.path.exists can hang; whether np.load does too is unverified.
+        if not self.skip_npz_check:
+            if not os.path.exists(npz_path):
+                return False
         if self.skip_disk_cache_validity_check:
             return True

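For context on that comment: on networked filesystems the stat() behind os.path.exists can stall a whole rank, so when the check is skipped the cache path is simply trusted. An alternative pattern, not used by this commit, is to attempt the load and treat any failure as a cache miss; a minimal sketch (the helper name is hypothetical):

import zipfile

import numpy as np

def try_load_npz(npz_path: str):
    # Open the cache file directly instead of stat-ing it first; a failed
    # open returns promptly, while os.path.exists may hang on some network
    # filesystems under heavy multinode load.
    try:
        with np.load(npz_path) as npz:
            return {key: npz[key] for key in npz.files}
    except (FileNotFoundError, OSError, zipfile.BadZipFile):
        return None  # treat any failure as a cache miss
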
@@ -152,9 +152,13 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
     def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
         # support old .npz
         old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
-        if os.path.exists(old_npz_file):
+        if self.skip_npz_check:
+            # TODO: Check user behaviour (lack of information)
             return old_npz_file
-        return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
+        elif os.path.exists(old_npz_file):
+            return old_npz_file
+        else:
+            return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix

     def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
         return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)

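For reference, the new-style name encodes the bucket resolution into the filename; a small illustration (the path and the ".npz" stand-in for self.suffix are assumptions for the example, the real suffix is strategy-specific):

import os

absolute_path = "/data/train/cat.png"
image_size = (1024, 768)
suffix = ".npz"  # stand-in for self.suffix
new_style = os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + suffix
print(new_style)  # /data/train/cat_1024x0768.npz
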
@@ -14,7 +14,7 @@ import shutil
 import time
 import typing
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
 import glob
 import math
 import os

@@ -218,6 +218,18 @@ class ImageInfo:
         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime

+    @staticmethod
+    def _pin_tensor(tensor):
+        return tensor.pin_memory() if tensor is not None else tensor
+
+    def pin_memory(self):
+        self.latents = self._pin_tensor(self.latents)
+        self.latents_flipped = self._pin_tensor(self.latents_flipped)
+        self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
+        self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
+        self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
+        self.alpha_mask = self._pin_tensor(self.alpha_mask)
+        return self


 class BucketManager:
     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:

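Background on why these cached tensors are pinned: page-locked host memory is what allows a later .to(device, non_blocking=True) to run as a truly asynchronous copy. A minimal, self-contained sketch of the effect (not part of the commit):

import torch

if torch.cuda.is_available():
    # pin_memory() moves the tensor into page-locked host RAM
    latents = torch.randn(4, 4, 64, 64).pin_memory()
    # with a pinned source this copy can overlap with GPU compute;
    # from pageable memory it silently falls back to a blocking copy
    gpu_latents = latents.to("cuda", non_blocking=True)
    torch.cuda.synchronize()  # make sure the copy finished before use
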
@@ -441,6 +453,7 @@ class BaseSubset:
         custom_attributes: Optional[Dict[str, Any]] = None,
         validation_seed: Optional[int] = None,
         validation_split: Optional[float] = 0.0,
+        skip_npz_check: Optional[bool] = SKIP_NPZ_PATH_CHECK,
     ) -> None:
         self.image_dir = image_dir
         self.alpha_mask = alpha_mask if alpha_mask is not None else False

@@ -471,6 +484,7 @@ class BaseSubset:
         self.validation_seed = validation_seed
         self.validation_split = validation_split

+        self.skip_npz_check = skip_npz_check  # For multinode training

 class DreamBoothSubset(BaseSubset):
     def __init__(

@@ -1841,6 +1855,11 @@ class BaseDataset(torch.utils.data.Dataset):
         example["bucket_reso"] = bucket_reso
         return example

+    def pin_memory(self):
+        for key in self.image_data.keys():
+            if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
+                self.image_data[key].pin_memory()
+

 class DreamBoothDataset(BaseDataset):
     IMAGE_INFO_CACHE_FILE = "metadata_cache.json"

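These duck-typed pin_memory() hooks mirror the contract torch.utils.data.DataLoader documents for custom batch types: with pin_memory=True, any batch element that exposes a pin_memory() method has it called after collation. A minimal sketch of that contract (class names are illustrative, adapted from the pattern in the PyTorch docs):

import torch
from torch.utils.data import DataLoader, Dataset

class CustomBatch:
    def __init__(self, tensors):
        self.inp = torch.stack(tensors)

    def pin_memory(self):
        # DataLoader calls this on the collated batch when pin_memory=True
        self.inp = self.inp.pin_memory()
        return self

class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(3)

loader = DataLoader(ToyDataset(), batch_size=4, pin_memory=True,
                    collate_fn=lambda batch: CustomBatch(batch))
batch = next(iter(loader))  # batch.inp is pinned on a CUDA machine
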
@@ -2129,6 +2148,10 @@ class DreamBoothDataset(BaseDataset):

         self.num_reg_images = num_reg_images

+    def pin_memory(self):
+        for key in self.image_data.keys():
+            if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
+                self.image_data[key].pin_memory()

 class FineTuningDataset(BaseDataset):
     def __init__(

@@ -2358,6 +2381,10 @@ class FineTuningDataset(BaseDataset):

         return npz_file_norm, npz_file_flip

+    def pin_memory(self):
+        for key in self.image_data.keys():
+            if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
+                self.image_data[key].pin_memory()

 class ControlNetDataset(BaseDataset):
     def __init__(

@@ -3840,6 +3867,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         action="store_true",
         help="persistent DataLoader workers (useful to reduce the time gap between epochs, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="pin memory for faster GPU loading (may have side effects on Windows) / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
     parser.add_argument(
         "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"

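A hypothetical invocation wiring the new flag into an existing run (the script name and other options are illustrative only):

accelerate launch train_network.py --pin_memory --persistent_data_loader_workers --max_data_loader_n_workers 4

Page-locked allocations are more constrained on Windows, which is presumably what the help text's caveat refers to.
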
@@ -5387,6 +5419,8 @@ def prepare_accelerator(args: argparse.Namespace):
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

+    dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory)
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,

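Note the indirection here: args.pin_memory drives Accelerate's non_blocking device transfers, while the actual pinning is done by the torch DataLoader further down; non-blocking copies only overlap with compute when the source tensors are pinned, so the two settings are meant to be flipped together. A minimal sketch of the Accelerate side (assumes an accelerate version recent enough to ship DataLoaderConfiguration with non_blocking):

from accelerate import Accelerator, DataLoaderConfiguration

# Prepared dataloaders move each batch with .to(device, non_blocking=True);
# combined with pinned host tensors this lets the copy overlap with compute.
dataloader_config = DataLoaderConfiguration(non_blocking=True)
accelerator = Accelerator(dataloader_config=dataloader_config)
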
@@ -5395,6 +5429,7 @@ def prepare_accelerator(args: argparse.Namespace):
         kwargs_handlers=kwargs_handlers,
         dynamo_backend=dynamo_backend,
         deepspeed_plugin=deepspeed_plugin,
+        dataloader_config=dataloader_config,
     )
     print("accelerator device:", accelerator.device)
     return accelerator

@@ -6590,6 +6625,10 @@ class collator_class:
             dataset.set_current_step(self.current_step.value)
         return examples[0]

+    def pin_memory(self):
+        if hasattr(self.dataset, 'pin_memory') and callable(self.dataset.pin_memory):
+            self.dataset.pin_memory()
+

 class LossRecorder:
     def __init__(self):

@@ -796,12 +796,15 @@ class NativeTrainer:
         # number of DataLoader processes: note that persistent_workers cannot be used with 0
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers

+        pin_memory = args.pin_memory
+
         train_dataloader = torch.utils.data.DataLoader(
             train_dataset_group,
             batch_size=1,
             shuffle=True,
             collate_fn=collator,
             num_workers=n_workers,
+            pin_memory=pin_memory,
             persistent_workers=args.persistent_data_loader_workers,
         )

@@ -811,6 +814,7 @@ class NativeTrainer:
             batch_size=1,
             collate_fn=collator,
             num_workers=n_workers,
+            pin_memory=pin_memory,
             persistent_workers=args.persistent_data_loader_workers,
         )

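One last note on how the two pinning paths interact: DataLoader(pin_memory=True) pins recursively, so tensors inside the per-example dicts these datasets return are pinned after collation even without the custom hooks, while the dataset-level pin_memory() pins the long-lived latent cache once up front. A small sketch of the recursive behaviour (not from the commit):

import torch
from torch.utils.data import DataLoader, Dataset

class DictDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {"latents": torch.randn(4, 8, 8), "caption": "a photo"}

loader = DataLoader(DictDataset(), batch_size=2, pin_memory=True)
batch = next(iter(loader))
print(batch["latents"].is_pinned())  # True on a machine with CUDA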