support block swap with fused_optimizer_pass

This commit is contained in:
Kohya S
2024-10-24 22:02:05 +09:00
parent 0286114bd2
commit f8c5146d71
2 changed files with 94 additions and 4 deletions

View File

@@ -4,6 +4,7 @@
# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution!
from ast import Tuple from ast import Tuple
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
import math import math
@@ -17,6 +18,8 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from transformers import CLIPTokenizer, T5TokenizerFast from transformers import CLIPTokenizer, T5TokenizerFast
from library.device_utils import clean_memory_on_device
from .utils import setup_logging from .utils import setup_logging
setup_logging() setup_logging()
@@ -848,6 +851,35 @@ class MMDiT(nn.Module):
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
return spatial_pos_embed return spatial_pos_embed
def enable_block_swap(self, num_blocks: int):
    """Enable block swapping and create the worker pool used for block moves.

    Args:
        num_blocks: Number of joint blocks to swap; stored as
            ``self.blocks_to_swap``.
    """
    # If block swap was already enabled, release the previous executor first;
    # simply overwriting the attribute would leak its worker thread.
    previous_pool = getattr(self, "thread_pool", None)
    if previous_pool is not None:
        previous_pool.shutdown(wait=True)

    self.blocks_to_swap = num_blocks

    n = 1  # async block swap. 1 is enough
    self.thread_pool = ThreadPoolExecutor(max_workers=n)
def move_to_device_except_swap_blocks(self, device: torch.device):
    """Move the model to ``device`` while leaving ``joint_blocks`` where they are.

    When block swap is disabled (``self.blocks_to_swap`` is falsy) this is
    just ``self.to(device)``.

    Args:
        device: Target device passed to ``self.to``.
    """
    # assume model is on cpu
    if not self.blocks_to_swap:
        self.to(device)
        return

    # Detach the joint blocks for the duration of the move so that
    # ``self.to(device)`` does not touch them.
    save_blocks = self.joint_blocks
    self.joint_blocks = None
    try:
        self.to(device)
    finally:
        # Always re-attach, even if the move fails (e.g. out of memory);
        # otherwise the model would be left without its joint blocks.
        self.joint_blocks = save_blocks
def prepare_block_swap_before_forward(self):
    """Put the joint blocks into their starting placement for a forward pass.

    The first ``len(joint_blocks) - blocks_to_swap`` blocks are moved to
    ``self.device`` and the remaining ``blocks_to_swap`` blocks are moved to
    the CPU, after which ``clean_memory_on_device(self.device)`` is called.
    Does nothing when block swap is not enabled.
    """
    swap_count = self.blocks_to_swap
    if swap_count in (None, 0):
        # Block swap is not enabled; silently skip rather than raise.
        return

    blocks = self.joint_blocks
    total = len(blocks)
    first_swapped = total - swap_count

    # Leading blocks stay resident on the compute device.
    for idx in range(first_swapped):
        blocks[idx].to(self.device)

    # Trailing blocks start on the CPU.
    for idx in range(first_swapped, total):
        blocks[idx].to("cpu")

    clean_memory_on_device(self.device)
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
@@ -881,8 +913,51 @@ class MMDiT(nn.Module):
1, 1,
) )
for block in self.joint_blocks: if not self.blocks_to_swap:
context, x = block(context, x, c) for block in self.joint_blocks:
context, x = block(context, x, c)
else:
futures = {}
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
    """Queue an asynchronous swap of two joint blocks on ``self.thread_pool``.

    Args:
        block_idx_to_cpu: Index in ``self.joint_blocks`` of the block to move to the CPU.
        block_idx_to_cuda: Index in ``self.joint_blocks`` of the block to move to ``self.device``.

    Returns:
        The future returned by ``self.thread_pool.submit``; its result is the
        tuple ``(block_idx_to_cpu, block_idx_to_cuda)``.
    """

    def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
        # Runs on the pool's worker thread.
        block_to_cpu.to("cpu", non_blocking=True)
        torch.cuda.empty_cache()

        block_to_cuda.to(self.device, non_blocking=True)

        # The copies above are non-blocking; synchronize before returning so
        # that a completed future means the moves have finished.
        torch.cuda.synchronize()

        # Return this function's own arguments (not the enclosing scope's
        # names), matching the equivalent helper in the training script.
        return bidx_to_cpu, bidx_to_cuda

    block_to_cpu = self.joint_blocks[block_idx_to_cpu]
    block_to_cuda = self.joint_blocks[block_idx_to_cuda]
    return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
def wait_for_blocks_move(block_idx, ftrs):
    """Block until the pending move registered for ``block_idx`` has finished.

    Args:
        block_idx: Key to look up in ``ftrs``.
        ftrs: Mapping from block index to a pending future; the entry for
            ``block_idx`` is removed before waiting on it.

    Returns nothing. If no move is registered for ``block_idx`` this is a
    no-op. Any exception raised by the move propagates from ``result()``.
    """
    try:
        pending = ftrs.pop(block_idx)
    except KeyError:
        # Nothing was scheduled for this block.
        return
    pending.result()
for block_idx, block in enumerate(self.joint_blocks):
wait_for_blocks_move(block_idx, futures)
context, x = block(context, x, c)
if block_idx < self.blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify
return x[:, :, :H, :W] return x[:, :, :H, :W]

View File

@@ -369,6 +369,14 @@ def train(args):
if not train_mmdit: if not train_mmdit:
mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared
# block swap
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
mmdit.enable_block_swap(args.blocks_to_swap)
if not cache_latents: if not cache_latents:
# move to accelerator device # move to accelerator device
vae.requires_grad_(False) vae.requires_grad_(False)
@@ -575,7 +583,9 @@ def train(args):
else: else:
# acceleratorがなんかよろしくやってくれるらしい # acceleratorがなんかよろしくやってくれるらしい
if train_mmdit: if train_mmdit:
mmdit = accelerator.prepare(mmdit) mmdit = accelerator.prepare(mmdit, device_placement=[not is_swapping_blocks])
if is_swapping_blocks:
accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
if train_clip: if train_clip:
clip_l = accelerator.prepare(clip_l) clip_l = accelerator.prepare(clip_l)
clip_g = accelerator.prepare(clip_g) clip_g = accelerator.prepare(clip_g)
@@ -600,8 +610,10 @@ def train(args):
block_to_cpu = block_to_cpu.to("cpu", non_blocking=True) block_to_cpu = block_to_cpu.to("cpu", non_blocking=True)
torch.cuda.empty_cache() torch.cuda.empty_cache()
# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
block_to_cuda = block_to_cuda.to(dvc, non_blocking=True) block_to_cuda = block_to_cuda.to(dvc, non_blocking=True)
torch.cuda.synchronize() torch.cuda.synchronize()
# print(f"Backward: Done moving blocks {bidx_to_cpu} and {bidx_to_cuda}")
return bidx_to_cpu, bidx_to_cuda return bidx_to_cpu, bidx_to_cuda
block_to_cpu = blocks[block_idx_to_cpu] block_to_cpu = blocks[block_idx_to_cpu]
@@ -639,7 +651,7 @@ def train(args):
grad_hook = None grad_hook = None
if blocks_to_swap: if blocks_to_swap:
is_block = param_name.startswith("double_blocks") is_block = param_name.startswith("joint_blocks")
if is_block: if is_block:
block_idx = int(param_name.split(".")[1]) block_idx = int(param_name.split(".")[1])
if block_idx not in handled_block_indices: if block_idx not in handled_block_indices:
@@ -805,6 +817,9 @@ def train(args):
init_kwargs=init_kwargs, init_kwargs=init_kwargs,
) )
if is_swapping_blocks:
accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward()
# For --sample_at_first # For --sample_at_first
optimizer_eval_fn() optimizer_eval_fn()
sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs) sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs)