mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
reduce peak VRAM usage by excluding some blocks from the move to CUDA
This commit is contained in:
@@ -251,7 +251,6 @@ def train(args):
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# load FLUX
|
||||
# if we load to cpu, flux.to(fp8) takes a long time
|
||||
flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -259,7 +258,8 @@ def train(args):
|
||||
|
||||
flux.requires_grad_(True)
|
||||
|
||||
if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None:
|
||||
is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None
|
||||
if is_swapping_blocks:
|
||||
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||
# This idea is based on 2kpr's great work. Thank you!
|
||||
logger.info(
|
||||
@@ -412,8 +412,11 @@ def train(args):
|
||||
training_models = [ds_model]
|
||||
|
||||
else:
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
flux = accelerator.prepare(flux)
|
||||
# accelerator does some magic
|
||||
# if we don't swap blocks, we can move the model to device
|
||||
flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
|
||||
if is_swapping_blocks:
|
||||
flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
@@ -539,7 +542,7 @@ def train(args):
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None:
|
||||
if is_swapping_blocks:
|
||||
flux.prepare_block_swap_before_forward()
|
||||
|
||||
# For --sample_at_first
|
||||
@@ -595,7 +598,7 @@ def train(args):
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
|
||||
# pack latents and get img_ids
|
||||
|
||||
@@ -953,6 +953,22 @@ class Flux(nn.Module):
|
||||
self.double_blocks_to_swap = double_blocks
|
||||
self.single_blocks_to_swap = single_blocks
|
||||
|
||||
def move_to_device_except_swap_blocks(self, device: torch.device):
    """Move this module to ``device`` while holding back the swappable block lists.

    When ``double_blocks_to_swap`` is truthy, the whole ``double_blocks``
    attribute is temporarily replaced with ``None`` before ``self.to(device)``
    is called and put back afterwards; the same is done for ``single_blocks``
    when ``single_blocks_to_swap`` is truthy. The caller uses this to reduce
    peak memory usage when block swapping is enabled.

    Args:
        device: Target device, passed straight to ``self.to``.
    """
    # assume model is on cpu
    detached = {}
    if self.double_blocks_to_swap:
        detached["double_blocks"] = self.double_blocks
        self.double_blocks = None
    if self.single_blocks_to_swap:
        detached["single_blocks"] = self.single_blocks
        self.single_blocks = None

    try:
        self.to(device)
    finally:
        # Re-attach even if the transfer raises (e.g. out of memory), so the
        # model is never left with its block lists set to None.
        for attr_name, blocks in detached.items():
            setattr(self, attr_name, blocks)
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
# move last n blocks to cpu: they are on cuda
|
||||
if self.double_blocks_to_swap:
|
||||
|
||||
Reference in New Issue
Block a user