diff --git a/library/sd3_models.py b/library/sd3_models.py
index c81aa479..e5c5887a 100644
--- a/library/sd3_models.py
+++ b/library/sd3_models.py
@@ -4,6 +4,7 @@
 # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution!
 
 from ast import Tuple
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from functools import partial
 import math
@@ -17,6 +18,8 @@
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 from transformers import CLIPTokenizer, T5TokenizerFast
+from library.device_utils import clean_memory_on_device
+
 from .utils import setup_logging
 
 setup_logging()
@@ -848,6 +851,35 @@ class MMDiT(nn.Module):
         spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
         return spatial_pos_embed
 
+    def enable_block_swap(self, num_blocks: int):
+        self.blocks_to_swap = num_blocks
+
+        n = 1  # async block swap. 1 is enough
+        self.thread_pool = ThreadPoolExecutor(max_workers=n)
+
+    def move_to_device_except_swap_blocks(self, device: torch.device):
+        # assume model is on cpu
+        if self.blocks_to_swap:
+            save_blocks = self.joint_blocks
+            self.joint_blocks = None
+
+        self.to(device)
+
+        if self.blocks_to_swap:
+            self.joint_blocks = save_blocks
+
+    def prepare_block_swap_before_forward(self):
+        # make: first n blocks are on cuda, and last n blocks are on cpu
+        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+            # raise ValueError("Block swap is not enabled.")
+            return
+        num_blocks = len(self.joint_blocks)
+        for i in range(num_blocks - self.blocks_to_swap):
+            self.joint_blocks[i].to(self.device)
+        for i in range(num_blocks - self.blocks_to_swap, num_blocks):
+            self.joint_blocks[i].to("cpu")
+        clean_memory_on_device(self.device)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -881,8 +913,51 @@ class MMDiT(nn.Module):
                 1,
             )
 
-        for block in self.joint_blocks:
-            context, x = block(context, x, c)
+        if not self.blocks_to_swap:
+            for block in self.joint_blocks:
+                context, x = block(context, x, c)
+        else:
+            futures = {}
+
+            def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
+                def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
+                    # print(f"Moving {bidx_to_cpu} to cpu.")
+                    block_to_cpu.to("cpu", non_blocking=True)
+                    torch.cuda.empty_cache()
+
+                    # print(f"Moving {bidx_to_cuda} to cuda.")
+                    block_to_cuda.to(self.device, non_blocking=True)
+
+                    torch.cuda.synchronize()
+                    # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
+                    return block_idx_to_cpu, block_idx_to_cuda
+
+                block_to_cpu = self.joint_blocks[block_idx_to_cpu]
+                block_to_cuda = self.joint_blocks[block_idx_to_cuda]
+                # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
+                return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
+
+            def wait_for_blocks_move(block_idx, ftrs):
+                if block_idx not in ftrs:
+                    return
+                # print(f"Waiting for move blocks: {block_idx}")
+                # start_time = time.perf_counter()
+                ftr = ftrs.pop(block_idx)
+                ftr.result()
+                # torch.cuda.synchronize()
+                # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
+
+            for block_idx, block in enumerate(self.joint_blocks):
+                wait_for_blocks_move(block_idx, futures)
+
+                context, x = block(context, x, c)
+
+                if block_idx < self.blocks_to_swap:
+                    block_idx_to_cpu = block_idx
+                    block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx
+                    future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
+                    futures[block_idx_to_cuda] = future
+
 
         x = self.final_layer(x, c, H, W)  # Our final layer combined UnPatchify
         return x[:, :, :H, :W]
diff --git a/sd3_train.py b/sd3_train.py
index d4ab13a3..5e2efa6f 100644
--- a/sd3_train.py
+++ b/sd3_train.py
@@ -369,6 +369,14 @@ def train(args):
     if not train_mmdit:
         mmdit.to(accelerator.device, dtype=weight_dtype)  # because of mmdit will not be prepared
 
+    # block swap
+    is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
+    if is_swapping_blocks:
+        # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
+        # This idea is based on 2kpr's great work. Thank you!
+        logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
+        mmdit.enable_block_swap(args.blocks_to_swap)
+
     if not cache_latents:
         # move to accelerator device
         vae.requires_grad_(False)
@@ -575,7 +583,9 @@ def train(args):
     else:
         # acceleratorがなんかよろしくやってくれるらしい
         if train_mmdit:
-            mmdit = accelerator.prepare(mmdit)
+            mmdit = accelerator.prepare(mmdit, device_placement=[not is_swapping_blocks])
+            if is_swapping_blocks:
+                accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
         if train_clip:
             clip_l = accelerator.prepare(clip_l)
             clip_g = accelerator.prepare(clip_g)
@@ -600,8 +610,10 @@
             block_to_cpu = block_to_cpu.to("cpu", non_blocking=True)
             torch.cuda.empty_cache()
 
+            # print(f"Backward: Move block {bidx_to_cuda} to CUDA")
             block_to_cuda = block_to_cuda.to(dvc, non_blocking=True)
             torch.cuda.synchronize()
+            # print(f"Backward: Done moving blocks {bidx_to_cpu} and {bidx_to_cuda}")
             return bidx_to_cpu, bidx_to_cuda
 
         block_to_cpu = blocks[block_idx_to_cpu]
@@ -639,7 +651,7 @@
         grad_hook = None
 
         if blocks_to_swap:
-            is_block = param_name.startswith("double_blocks")
+            is_block = param_name.startswith("joint_blocks")
             if is_block:
                 block_idx = int(param_name.split(".")[1])
                 if block_idx not in handled_block_indices:
@@ -805,6 +817,9 @@
             init_kwargs=init_kwargs,
         )
 
+    if is_swapping_blocks:
+        accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward()
+
     # For --sample_at_first
     optimizer_eval_fn()
     sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs)