Add pin_memory to DataLoader and update ImageInfo to support memory pinning

This commit is contained in:
rockerBOO
2025-01-23 10:39:01 -05:00
parent e89653975d
commit c4b0bb6fce
17 changed files with 43 additions and 0 deletions

View File

@@ -242,6 +242,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -126,6 +126,7 @@ def main(args):
batch_size=args.batch_size,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
pin_memory=args.pin_memory,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)

View File

@@ -113,6 +113,7 @@ def main(args):
dataset,
batch_size=args.batch_size,
shuffle=False,
pin_memory=args.pin_memory,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,

View File

@@ -122,6 +122,7 @@ def main(args):
dataset,
batch_size=1,
shuffle=False,
pin_memory=args.pin_memory,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,

View File

@@ -335,6 +335,7 @@ def main(args):
dataset,
batch_size=args.batch_size,
shuffle=False,
pin_memory=args.pin_memory,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,

View File

@@ -397,6 +397,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -398,6 +398,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -176,6 +176,19 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
@staticmethod
def _pin_tensor(tensor):
return tensor.pin_memory() if tensor is not None else tensor
def pin_memory(self):
self.latents = self._pin_tensor(self.latents)
self.latents_flipped = self._pin_tensor(self.latents_flipped)
self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
self.alpha_mask = self._pin_tensor(self.alpha_mask)
return self
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -2036,6 +2049,11 @@ class DreamBoothDataset(BaseDataset):
self.num_reg_images = num_reg_images
def pin_memory(self):
for key in self.image_data.keys():
if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
self.image_data[key].pin_memory()
class FineTuningDataset(BaseDataset):
def __init__(
@@ -3734,6 +3752,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument(
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
@@ -6379,6 +6402,10 @@ class collator_class:
dataset.set_current_step(self.current_step.value)
return examples[0]
def pin_memory(self):
if hasattr(self, 'pin_memory') and callable(self.pin_memory):
self.dataset.pin_memory()
class LossRecorder:
def __init__(self):

View File

@@ -498,6 +498,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -430,6 +430,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -281,6 +281,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -272,6 +272,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -220,6 +220,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -210,6 +210,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -577,6 +577,7 @@ class NetworkTrainer:
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -408,6 +408,7 @@ class TextualInversionTrainer:
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -316,6 +316,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)