mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Add pin_memory to DataLoader and update ImageInfo to support
This commit is contained in:
@@ -242,6 +242,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -126,6 +126,7 @@ def main(args):
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
@@ -113,6 +113,7 @@ def main(args):
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
|
||||
@@ -122,6 +122,7 @@ def main(args):
|
||||
dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
|
||||
@@ -335,6 +335,7 @@ def main(args):
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.max_data_loader_n_workers,
|
||||
collate_fn=collate_fn_remove_corrupted,
|
||||
drop_last=False,
|
||||
|
||||
@@ -397,6 +397,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -398,6 +398,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -176,6 +176,19 @@ class ImageInfo:
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
@staticmethod
|
||||
def _pin_tensor(tensor):
|
||||
return tensor.pin_memory() if tensor is not None else tensor
|
||||
|
||||
def pin_memory(self):
|
||||
self.latents = self._pin_tensor(self.latents)
|
||||
self.latents_flipped = self._pin_tensor(self.latents_flipped)
|
||||
self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
|
||||
self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
|
||||
self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
|
||||
self.alpha_mask = self._pin_tensor(self.alpha_mask)
|
||||
return self
|
||||
|
||||
|
||||
class BucketManager:
|
||||
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
||||
@@ -2036,6 +2049,11 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
self.num_reg_images = num_reg_images
|
||||
|
||||
def pin_memory(self):
|
||||
for key in self.image_data.keys():
|
||||
if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
|
||||
self.image_data[key].pin_memory()
|
||||
|
||||
|
||||
class FineTuningDataset(BaseDataset):
|
||||
def __init__(
|
||||
@@ -3734,6 +3752,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
action="store_true",
|
||||
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pin_memory",
|
||||
action="store_true",
|
||||
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
|
||||
@@ -6379,6 +6402,10 @@ class collator_class:
|
||||
dataset.set_current_step(self.current_step.value)
|
||||
return examples[0]
|
||||
|
||||
def pin_memory(self):
|
||||
if hasattr(self, 'pin_memory') and callable(self.pin_memory):
|
||||
self.dataset.pin_memory()
|
||||
|
||||
|
||||
class LossRecorder:
|
||||
def __init__(self):
|
||||
|
||||
@@ -498,6 +498,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -430,6 +430,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -281,6 +281,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -272,6 +272,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -220,6 +220,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -210,6 +210,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -577,6 +577,7 @@ class NetworkTrainer:
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -408,6 +408,7 @@ class TextualInversionTrainer:
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
@@ -316,6 +316,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
pin_memory=args.pin_memory,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user