mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
reduce peak GPU memory usage before training
This commit is contained in:
@@ -471,7 +471,7 @@ class AttentionLinears(nn.Module):
|
|||||||
num_heads: int = 8,
|
num_heads: int = 8,
|
||||||
qkv_bias: bool = False,
|
qkv_bias: bool = False,
|
||||||
pre_only: bool = False,
|
pre_only: bool = False,
|
||||||
qk_norm: str = None,
|
qk_norm: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
|||||||
@@ -2410,6 +2410,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
|
|||||||
|
|
||||||
|
|
||||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||||
|
# TODO update to use CachingStrategy
|
||||||
def load_latents_from_disk(
|
def load_latents_from_disk(
|
||||||
npz_path,
|
npz_path,
|
||||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||||
|
|||||||
44
sd3_train.py
44
sd3_train.py
@@ -458,6 +458,28 @@ def train(args):
|
|||||||
# text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
# text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||||
# text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
# text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||||
|
|
||||||
|
# TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
|
clip_l.to("cpu", dtype=torch.float32)
|
||||||
|
clip_g.to("cpu", dtype=torch.float32)
|
||||||
|
if t5xxl is not None:
|
||||||
|
t5xxl.to("cpu", dtype=torch.float32)
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
else:
|
||||||
|
# make sure Text Encoders are on GPU
|
||||||
|
# TODO support CPU for text encoders
|
||||||
|
clip_l.to(accelerator.device)
|
||||||
|
clip_g.to(accelerator.device)
|
||||||
|
if t5xxl is not None:
|
||||||
|
t5xxl.to(accelerator.device)
|
||||||
|
|
||||||
|
# TODO cache sample prompt's embeddings to free text encoder's memory
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
if not args.save_t5xxl:
|
||||||
|
t5xxl = None # free memory
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||||
args,
|
args,
|
||||||
@@ -482,28 +504,6 @@ def train(args):
|
|||||||
# text_encoder2 = accelerator.prepare(text_encoder2)
|
# text_encoder2 = accelerator.prepare(text_encoder2)
|
||||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
# TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する
|
|
||||||
if args.cache_text_encoder_outputs:
|
|
||||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
|
||||||
clip_l.to("cpu", dtype=torch.float32)
|
|
||||||
clip_g.to("cpu", dtype=torch.float32)
|
|
||||||
if t5xxl is not None:
|
|
||||||
t5xxl.to("cpu", dtype=torch.float32)
|
|
||||||
clean_memory_on_device(accelerator.device)
|
|
||||||
else:
|
|
||||||
# make sure Text Encoders are on GPU
|
|
||||||
# TODO support CPU for text encoders
|
|
||||||
clip_l.to(accelerator.device)
|
|
||||||
clip_g.to(accelerator.device)
|
|
||||||
if t5xxl is not None:
|
|
||||||
t5xxl.to(accelerator.device)
|
|
||||||
|
|
||||||
# TODO cache sample prompt's embeddings to free text encoder's memory
|
|
||||||
if args.cache_text_encoder_outputs:
|
|
||||||
if not args.save_t5xxl:
|
|
||||||
t5xxl = None # free memory
|
|
||||||
clean_memory_on_device(accelerator.device)
|
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine handles it instead.
|
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine handles it instead.
|
||||||
|
|||||||
Reference in New Issue
Block a user