From 082f13658bdbaed872ede6c0a7a75ab1a5f3712d Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Fri, 12 Jul 2024 21:28:01 +0900
Subject: [PATCH] reduce peak GPU memory usage before training

---
 library/sd3_models.py |  2 +-
 library/train_util.py |  1 +
 sd3_train.py          | 44 ++++++++++++++++++++++----------------------
 3 files changed, 24 insertions(+), 23 deletions(-)

diff --git a/library/sd3_models.py b/library/sd3_models.py
index a1ff1e75..ec8e1bbd 100644
--- a/library/sd3_models.py
+++ b/library/sd3_models.py
@@ -471,7 +471,7 @@ class AttentionLinears(nn.Module):
         num_heads: int = 8,
         qkv_bias: bool = False,
         pre_only: bool = False,
-        qk_norm: str = None,
+        qk_norm: Optional[str] = None,
     ):
         super().__init__()
         self.num_heads = num_heads
diff --git a/library/train_util.py b/library/train_util.py
index 9db226ea..7af0070e 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2410,6 +2410,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
 
 
 # return values: latents_tensor, (original_size width, original_size height), (crop left, crop top)
+# TODO update to use CachingStrategy
 def load_latents_from_disk(
     npz_path,
 ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
diff --git a/sd3_train.py b/sd3_train.py
index e2f622e4..f34e4712 100644
--- a/sd3_train.py
+++ b/sd3_train.py
@@ -458,6 +458,28 @@ def train(args):
     # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
     # text_encoder1.text_model.final_layer_norm.requires_grad_(False)
 
+    # when caching TextEncoder outputs, the outputs have already been obtained, so move the encoders to the CPU
+    if args.cache_text_encoder_outputs:
+        # move Text Encoders to CPU for sampling images. Text Encoders don't work on CPU with fp16
+        clip_l.to("cpu", dtype=torch.float32)
+        clip_g.to("cpu", dtype=torch.float32)
+        if t5xxl is not None:
+            t5xxl.to("cpu", dtype=torch.float32)
+        clean_memory_on_device(accelerator.device)
+    else:
+        # make sure Text Encoders are on GPU
+        # TODO support CPU for text encoders
+        clip_l.to(accelerator.device)
+        clip_g.to(accelerator.device)
+        if t5xxl is not None:
+            t5xxl.to(accelerator.device)
+
+    # TODO cache sample prompt's embeddings to free text encoder's memory
+    if args.cache_text_encoder_outputs:
+        if not args.save_t5xxl:
+            t5xxl = None  # free memory
+        clean_memory_on_device(accelerator.device)
+
     if args.deepspeed:
         ds_model = deepspeed_utils.prepare_deepspeed_model(
             args,
@@ -482,28 +504,6 @@ def train(args):
     # text_encoder2 = accelerator.prepare(text_encoder2)
     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
-    # when caching TextEncoder outputs, the outputs have already been obtained, so move the encoders to the CPU
-    if args.cache_text_encoder_outputs:
-        # move Text Encoders to CPU for sampling images. Text Encoders don't work on CPU with fp16
-        clip_l.to("cpu", dtype=torch.float32)
-        clip_g.to("cpu", dtype=torch.float32)
-        if t5xxl is not None:
-            t5xxl.to("cpu", dtype=torch.float32)
-        clean_memory_on_device(accelerator.device)
-    else:
-        # make sure Text Encoders are on GPU
-        # TODO support CPU for text encoders
-        clip_l.to(accelerator.device)
-        clip_g.to(accelerator.device)
-        if t5xxl is not None:
-            t5xxl.to(accelerator.device)
-
-    # TODO cache sample prompt's embeddings to free text encoder's memory
-    if args.cache_text_encoder_outputs:
-        if not args.save_t5xxl:
-            t5xxl = None  # free memory
-        clean_memory_on_device(accelerator.device)
-
     # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scale in fp16
     if args.full_fp16:
         # during deepspeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; let the deepspeed engine do it