reduce peak VRAM usage by excluding some blocks from being moved to CUDA

This commit is contained in:
Kohya S
2024-08-19 21:55:28 +09:00
parent d034032a5d
commit 6e72a799c8
2 changed files with 25 additions and 6 deletions

View File

@@ -251,7 +251,6 @@ def train(args):
clean_memory_on_device(accelerator.device)
# load FLUX
# if we load to cpu, flux.to(fp8) takes a long time
flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
if args.gradient_checkpointing:
@@ -259,7 +258,8 @@ def train(args):
flux.requires_grad_(True)
if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None:
is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None
if is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(
@@ -412,8 +412,11 @@ def train(args):
training_models = [ds_model]
else:
# acceleratorがなんかよろしくやってくれるらしい
flux = accelerator.prepare(flux)
# accelerator does some magic
# if we don't swap blocks, we can move the model to the device
flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
if is_swapping_blocks:
flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
@@ -539,7 +542,7 @@ def train(args):
init_kwargs=init_kwargs,
)
if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None:
if is_swapping_blocks:
flux.prepare_block_swap_before_forward()
# For --sample_at_first
@@ -595,7 +598,7 @@ def train(args):
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
)
# pack latents and get img_ids

View File

@@ -953,6 +953,22 @@ class Flux(nn.Module):
self.double_blocks_to_swap = double_blocks
self.single_blocks_to_swap = single_blocks
def move_to_device_except_swap_blocks(self, device: torch.device):
    """Move this model to `device` while leaving the to-be-swapped
    double/single blocks where they are, reducing peak device memory.

    The blocks flagged by `self.double_blocks_to_swap` / `self.single_blocks_to_swap`
    are temporarily detached so `self.to(device)` skips them, then reattached.
    """
    # assume model is on cpu
    if self.double_blocks_to_swap:
        # detach double blocks so the .to(device) call below does not move them
        save_double_blocks = self.double_blocks
        self.double_blocks = None
    if self.single_blocks_to_swap:
        # likewise for single blocks
        save_single_blocks = self.single_blocks
        self.single_blocks = None
    # moves everything still attached (all non-swapped submodules) to the target device
    self.to(device)
    # reattach the saved blocks; they remain on their original device (presumably CPU — verify against caller)
    if self.double_blocks_to_swap:
        self.double_blocks = save_double_blocks
    if self.single_blocks_to_swap:
        self.single_blocks = save_single_blocks
def prepare_block_swap_before_forward(self):
# move last n blocks to cpu: they are on cuda
if self.double_blocks_to_swap: