mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support block swap with fused_optimizer_pass
This commit is contained in:
@@ -4,6 +4,7 @@
|
|||||||
# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution!
|
# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution!
|
||||||
|
|
||||||
from ast import Tuple
|
from ast import Tuple
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import math
|
import math
|
||||||
@@ -17,6 +18,8 @@ import torch.nn.functional as F
|
|||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory_on_device
|
||||||
|
|
||||||
from .utils import setup_logging
|
from .utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
@@ -848,6 +851,35 @@ class MMDiT(nn.Module):
|
|||||||
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||||
return spatial_pos_embed
|
return spatial_pos_embed
|
||||||
|
|
||||||
|
def enable_block_swap(self, num_blocks: int):
    """Enable block swapping and create the worker pool that performs the moves.

    Stores ``num_blocks`` in ``self.blocks_to_swap`` and creates a
    single-worker ``ThreadPoolExecutor`` in ``self.thread_pool``.

    Args:
        num_blocks: number of blocks to swap; stored as ``self.blocks_to_swap``.
    """
    # If block swap was already enabled, shut the old pool down first so that
    # its worker thread is not leaked when the attribute is overwritten.
    previous_pool = getattr(self, "thread_pool", None)
    if previous_pool is not None:
        previous_pool.shutdown(wait=True)

    self.blocks_to_swap = num_blocks

    # async block swap. 1 worker is enough
    self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||||
|
|
||||||
|
def move_to_device_except_swap_blocks(self, device: torch.device):
    """Move the model to ``device`` without moving ``self.joint_blocks``.

    When block swap is disabled (``self.blocks_to_swap`` is falsy) this is a
    plain ``self.to(device)``. Otherwise ``self.joint_blocks`` is temporarily
    replaced by ``None`` for the duration of ``self.to(device)`` and put back
    afterwards.

    Args:
        device: target device passed to ``self.to``.
    """
    # assume model is on cpu
    if not self.blocks_to_swap:
        self.to(device)
        return

    # Detach the blocks so that ``self.to`` does not see them.
    save_blocks = self.joint_blocks
    self.joint_blocks = None
    try:
        self.to(device)
    finally:
        # Re-attach even when the move fails (e.g. out of memory), so the
        # model is never left with ``joint_blocks`` set to None.
        self.joint_blocks = save_blocks
|
||||||
|
|
||||||
|
def prepare_block_swap_before_forward(self):
    """Put the joint blocks in their starting positions for a forward pass.

    The first ``len(self.joint_blocks) - self.blocks_to_swap`` blocks are moved
    to ``self.device`` and the remaining blocks are moved to the CPU, after
    which ``clean_memory_on_device(self.device)`` is called.

    Does nothing when block swap is not enabled (``self.blocks_to_swap`` is
    ``None`` or ``0``).
    """
    if not self.blocks_to_swap:
        return

    # Blocks below this index start on the device; the rest start on the CPU.
    num_on_device = len(self.joint_blocks) - self.blocks_to_swap
    for block_idx, block in enumerate(self.joint_blocks):
        block.to(self.device if block_idx < num_on_device else "cpu")

    clean_memory_on_device(self.device)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -881,8 +913,51 @@ class MMDiT(nn.Module):
|
|||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.blocks_to_swap:
|
||||||
for block in self.joint_blocks:
|
for block in self.joint_blocks:
|
||||||
context, x = block(context, x, c)
|
context, x = block(context, x, c)
|
||||||
|
else:
|
||||||
|
futures = {}
|
||||||
|
|
||||||
|
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
    """Schedule a block swap on ``self.thread_pool`` and return its future.

    The worker moves ``self.joint_blocks[block_idx_to_cpu]`` to the CPU and
    ``self.joint_blocks[block_idx_to_cuda]`` to ``self.device``.

    Args:
        block_idx_to_cpu: index of the block to move to the CPU.
        block_idx_to_cuda: index of the block to move to ``self.device``.

    Returns:
        The ``Future`` returned by ``self.thread_pool.submit``; its result is
        the tuple ``(block_idx_to_cpu, block_idx_to_cuda)``.
    """

    def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
        block_to_cpu.to("cpu", non_blocking=True)
        # presumably releases the cached memory of the offloaded block before
        # the next block is loaded -- TODO confirm this is still required
        torch.cuda.empty_cache()

        block_to_cuda.to(self.device, non_blocking=True)

        # The copies above are requested as non-blocking; synchronize so the
        # future only completes once the queued CUDA work has finished.
        torch.cuda.synchronize()
        # Return this worker's own arguments rather than the enclosing
        # function's variables (same values, no reliance on the closure).
        return bidx_to_cpu, bidx_to_cuda

    block_to_cpu = self.joint_blocks[block_idx_to_cpu]
    block_to_cuda = self.joint_blocks[block_idx_to_cuda]
    return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
|
||||||
|
|
||||||
|
def wait_for_blocks_move(block_idx, ftrs):
    """Wait for the pending move registered under ``block_idx``, if there is one.

    Args:
        block_idx: key to look up in ``ftrs``.
        ftrs: dict mapping a block index to the future of its pending move.
            The entry for ``block_idx`` is removed from this dict.

    Any exception raised inside the move is re-raised here by ``Future.result``.
    """
    # Single lookup: remove the entry and get the future in one step.
    ftr = ftrs.pop(block_idx, None)
    if ftr is not None:
        ftr.result()
|
||||||
|
|
||||||
|
for block_idx, block in enumerate(self.joint_blocks):
|
||||||
|
wait_for_blocks_move(block_idx, futures)
|
||||||
|
|
||||||
|
context, x = block(context, x, c)
|
||||||
|
|
||||||
|
if block_idx < self.blocks_to_swap:
|
||||||
|
block_idx_to_cpu = block_idx
|
||||||
|
block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx
|
||||||
|
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
|
||||||
|
futures[block_idx_to_cuda] = future
|
||||||
|
|
||||||
x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify
|
x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify
|
||||||
return x[:, :, :H, :W]
|
return x[:, :, :H, :W]
|
||||||
|
|
||||||
|
|||||||
19
sd3_train.py
19
sd3_train.py
@@ -369,6 +369,14 @@ def train(args):
|
|||||||
if not train_mmdit:
|
if not train_mmdit:
|
||||||
mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared
|
mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdit will not be prepared
|
||||||
|
|
||||||
|
# block swap
|
||||||
|
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||||
|
if is_swapping_blocks:
|
||||||
|
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||||
|
# This idea is based on 2kpr's great work. Thank you!
|
||||||
|
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||||
|
mmdit.enable_block_swap(args.blocks_to_swap)
|
||||||
|
|
||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
# move to accelerator device
|
# move to accelerator device
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
@@ -575,7 +583,9 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if train_mmdit:
|
if train_mmdit:
|
||||||
mmdit = accelerator.prepare(mmdit)
|
mmdit = accelerator.prepare(mmdit, device_placement=[not is_swapping_blocks])
|
||||||
|
if is_swapping_blocks:
|
||||||
|
accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
||||||
if train_clip:
|
if train_clip:
|
||||||
clip_l = accelerator.prepare(clip_l)
|
clip_l = accelerator.prepare(clip_l)
|
||||||
clip_g = accelerator.prepare(clip_g)
|
clip_g = accelerator.prepare(clip_g)
|
||||||
@@ -600,8 +610,10 @@ def train(args):
|
|||||||
block_to_cpu = block_to_cpu.to("cpu", non_blocking=True)
|
block_to_cpu = block_to_cpu.to("cpu", non_blocking=True)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# print(f"Backward: Move block {bidx_to_cuda} to CUDA")
|
||||||
block_to_cuda = block_to_cuda.to(dvc, non_blocking=True)
|
block_to_cuda = block_to_cuda.to(dvc, non_blocking=True)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
# print(f"Backward: Done moving blocks {bidx_to_cpu} and {bidx_to_cuda}")
|
||||||
return bidx_to_cpu, bidx_to_cuda
|
return bidx_to_cpu, bidx_to_cuda
|
||||||
|
|
||||||
block_to_cpu = blocks[block_idx_to_cpu]
|
block_to_cpu = blocks[block_idx_to_cpu]
|
||||||
@@ -639,7 +651,7 @@ def train(args):
|
|||||||
grad_hook = None
|
grad_hook = None
|
||||||
|
|
||||||
if blocks_to_swap:
|
if blocks_to_swap:
|
||||||
is_block = param_name.startswith("double_blocks")
|
is_block = param_name.startswith("joint_blocks")
|
||||||
if is_block:
|
if is_block:
|
||||||
block_idx = int(param_name.split(".")[1])
|
block_idx = int(param_name.split(".")[1])
|
||||||
if block_idx not in handled_block_indices:
|
if block_idx not in handled_block_indices:
|
||||||
@@ -805,6 +817,9 @@ def train(args):
|
|||||||
init_kwargs=init_kwargs,
|
init_kwargs=init_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_swapping_blocks:
|
||||||
|
accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
# For --sample_at_first
|
# For --sample_at_first
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs)
|
sd3_train_utils.sample_images(accelerator, args, 0, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user