reduce peak GPU memory usage before training

This commit is contained in:
Kohya S
2024-07-12 21:28:01 +09:00
parent b8896aad40
commit 082f13658b
3 changed files with 24 additions and 23 deletions

View File

@@ -471,7 +471,7 @@ class AttentionLinears(nn.Module):
num_heads: int = 8, num_heads: int = 8,
qkv_bias: bool = False, qkv_bias: bool = False,
pre_only: bool = False, pre_only: bool = False,
qk_norm: str = None, qk_norm: Optional[str] = None,
): ):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads

View File

@@ -2410,6 +2410,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
# TODO update to use CachingStrategy
def load_latents_from_disk( def load_latents_from_disk(
npz_path, npz_path,
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:

View File

@@ -458,6 +458,28 @@ def train(args):
# text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) # text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
# text_encoder1.text_model.final_layer_norm.requires_grad_(False) # text_encoder1.text_model.final_layer_norm.requires_grad_(False)
# TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
clip_l.to("cpu", dtype=torch.float32)
clip_g.to("cpu", dtype=torch.float32)
if t5xxl is not None:
t5xxl.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
# TODO support CPU for text encoders
clip_l.to(accelerator.device)
clip_g.to(accelerator.device)
if t5xxl is not None:
t5xxl.to(accelerator.device)
# TODO cache sample prompt's embeddings to free text encoder's memory
if args.cache_text_encoder_outputs:
if not args.save_t5xxl:
t5xxl = None # free memory
clean_memory_on_device(accelerator.device)
if args.deepspeed: if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model( ds_model = deepspeed_utils.prepare_deepspeed_model(
args, args,
@@ -482,28 +504,6 @@ def train(args):
# text_encoder2 = accelerator.prepare(text_encoder2) # text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
clip_l.to("cpu", dtype=torch.float32)
clip_g.to("cpu", dtype=torch.float32)
if t5xxl is not None:
t5xxl.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
# TODO support CPU for text encoders
clip_l.to(accelerator.device)
clip_g.to(accelerator.device)
if t5xxl is not None:
t5xxl.to(accelerator.device)
# TODO cache sample prompt's embeddings to free text encoder's memory
if args.cache_text_encoder_outputs:
if not args.save_t5xxl:
t5xxl = None # free memory
clean_memory_on_device(accelerator.device)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする # 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16: if args.full_fp16:
# During deepspeed training, accelerate does not handle fp16/bf16 mixed precision directly via scaler. Let deepspeed engine do. # During deepspeed training, accelerate does not handle fp16/bf16 mixed precision directly via scaler. Let deepspeed engine do.