mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add memory efficient training for FLUX.1
This commit is contained in:
@@ -4,6 +4,11 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
@@ -466,6 +471,33 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso
|
||||
|
||||
|
||||
# region layers
|
||||
|
||||
|
||||
# for cpu_offload_checkpointing
|
||||
|
||||
|
||||
def to_cuda(x):
    """Recursively move every tensor found in ``x`` to CUDA.

    Helper for CPU-offload checkpointing: the checkpointed wrapper calls this
    on its inputs before running the wrapped forward function.

    Args:
        x: A tensor, a list/tuple/dict (possibly nested) of such values, or
            any other object.

    Returns:
        The same structure with each ``torch.Tensor`` replaced by the result
        of ``tensor.cuda()``. Lists stay lists and tuples stay tuples; dict
        keys are kept as-is. Objects of any other type are returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.cuda()
    if isinstance(x, tuple):
        # Keep tuples as tuples so callers get back the container type they
        # passed in (previously tuples were silently turned into lists).
        return tuple(to_cuda(elem) for elem in x)
    if isinstance(x, list):
        return [to_cuda(elem) for elem in x]
    if isinstance(x, dict):
        return {k: to_cuda(v) for k, v in x.items()}
    return x
|
||||
|
||||
|
||||
def to_cpu(x):
    """Recursively move every tensor found in ``x`` to the CPU.

    Helper for CPU-offload checkpointing: the checkpointed wrapper calls this
    on the outputs of the wrapped forward function before returning them.

    Args:
        x: A tensor, a list/tuple/dict (possibly nested) of such values, or
            any other object.

    Returns:
        The same structure with each ``torch.Tensor`` replaced by the result
        of ``tensor.cpu()``. Lists stay lists and tuples stay tuples; dict
        keys are kept as-is. Objects of any other type are returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.cpu()
    if isinstance(x, tuple):
        # Keep tuples as tuples: block forwards are annotated as returning
        # tuples, and this function wraps their outputs on the offload path
        # (previously tuples were silently turned into lists).
        return tuple(to_cpu(elem) for elem in x)
    if isinstance(x, list):
        return [to_cpu(elem) for elem in x]
    if isinstance(x, dict):
        return {k: to_cpu(v) for k, v in x.items()}
    return x
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||
super().__init__()
|
||||
@@ -648,16 +680,15 @@ class DoubleStreamBlock(nn.Module):
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
||||
self.gradient_checkpointing = True
|
||||
# self.img_attn.enable_gradient_checkpointing()
|
||||
# self.txt_attn.enable_gradient_checkpointing()
|
||||
self.cpu_offload_checkpointing = cpu_offload
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
# self.img_attn.disable_gradient_checkpointing()
|
||||
# self.txt_attn.disable_gradient_checkpointing()
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
@@ -694,11 +725,24 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
return img, txt
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
||||
if not self.cpu_offload_checkpointing:
|
||||
return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False)
|
||||
# cpu offload checkpointing
|
||||
|
||||
def create_custom_forward(func):
|
||||
def custom_forward(*inputs):
|
||||
cuda_inputs = to_cuda(inputs)
|
||||
outputs = func(*cuda_inputs)
|
||||
return to_cpu(outputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe)
|
||||
|
||||
else:
|
||||
return self._forward(*args, **kwargs)
|
||||
return self._forward(img, txt, vec, pe)
|
||||
|
||||
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
@@ -747,12 +791,15 @@ class SingleStreamBlock(nn.Module):
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
||||
self.gradient_checkpointing = True
|
||||
self.cpu_offload_checkpointing = cpu_offload
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
@@ -768,11 +815,24 @@ class SingleStreamBlock(nn.Module):
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
||||
if not self.cpu_offload_checkpointing:
|
||||
return checkpoint(self._forward, x, vec, pe, use_reentrant=False)
|
||||
|
||||
# cpu offload checkpointing
|
||||
|
||||
def create_custom_forward(func):
|
||||
def custom_forward(*inputs):
|
||||
cuda_inputs = to_cuda(inputs)
|
||||
outputs = func(*cuda_inputs)
|
||||
return to_cpu(outputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe)
|
||||
else:
|
||||
return self._forward(*args, **kwargs)
|
||||
return self._forward(x, vec, pe)
|
||||
|
||||
# def forward(self, x: Tensor, vec: Tensor, pe: Tensor):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
@@ -849,6 +909,9 @@ class Flux(nn.Module):
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
self.double_blocks_to_swap = None
|
||||
self.single_blocks_to_swap = None
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
@@ -858,8 +921,9 @@ class Flux(nn.Module):
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
||||
self.gradient_checkpointing = True
|
||||
self.cpu_offload_checkpointing = cpu_offload
|
||||
|
||||
self.time_in.enable_gradient_checkpointing()
|
||||
self.vector_in.enable_gradient_checkpointing()
|
||||
@@ -867,12 +931,13 @@ class Flux(nn.Module):
|
||||
self.guidance_in.enable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks + self.single_blocks:
|
||||
block.enable_gradient_checkpointing()
|
||||
block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
|
||||
|
||||
print("FLUX: Gradient checkpointing enabled.")
|
||||
print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
self.time_in.disable_gradient_checkpointing()
|
||||
self.vector_in.disable_gradient_checkpointing()
|
||||
@@ -884,6 +949,24 @@ class Flux(nn.Module):
|
||||
|
||||
print("FLUX: Gradient checkpointing disabled.")
|
||||
|
||||
def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]):
    """Record how many double and single blocks take part in block swapping.

    The values are only stored here. A falsy value (``None`` or ``0``)
    leaves swapping off for that kind of block.

    Args:
        double_blocks: Number of double-stream blocks to swap, or ``None``.
        single_blocks: Number of single-stream blocks to swap, or ``None``.
    """
    self.double_blocks_to_swap, self.single_blocks_to_swap = double_blocks, single_blocks
|
||||
|
||||
def prepare_block_swap_before_forward(self):
    """Put the blocks on their starting devices for a block-swapping forward pass.

    For each kind of block (double first, then single) whose swap count is
    set to a truthy value ``n``: every block except the last ``n`` is moved
    to ``self.device`` and the last ``n`` blocks are moved to ``"cpu"``.
    Afterwards ``clean_memory_on_device(self.device)`` is called, whether or
    not any swapping is configured.
    """
    block_groups = (
        (self.double_blocks, self.double_blocks_to_swap),
        (self.single_blocks, self.single_blocks_to_swap),
    )
    for blocks, num_to_swap in block_groups:
        if not num_to_swap:
            continue
        # Index of the first block that should start out on the cpu.
        first_cpu_idx = len(blocks) - num_to_swap
        for idx in range(first_cpu_idx):
            blocks[idx].to(self.device)
        for idx in range(first_cpu_idx, len(blocks)):
            blocks[idx].to("cpu")  # , non_blocking=True)
    clean_memory_on_device(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
@@ -910,14 +993,75 @@ class Flux(nn.Module):
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
if not self.double_blocks_to_swap:
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
else:
|
||||
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
|
||||
for block_idx in range(self.double_blocks_to_swap):
|
||||
block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx]
|
||||
if block.parameters().__next__().device.type != "cpu":
|
||||
block.to("cpu") # , non_blocking=True)
|
||||
# print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.")
|
||||
|
||||
block = self.double_blocks[block_idx]
|
||||
if block.parameters().__next__().device.type == "cpu":
|
||||
block.to(self.device)
|
||||
# print(f"Moved double block {block_idx} to cuda.")
|
||||
|
||||
to_cpu_block_index = 0
|
||||
for block_idx, block in enumerate(self.double_blocks):
|
||||
# move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
|
||||
moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap
|
||||
if moving:
|
||||
block.to(self.device) # move to cuda
|
||||
# print(f"Moved double block {block_idx} to cuda.")
|
||||
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
if moving:
|
||||
self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
||||
# print(f"Moved double block {to_cpu_block_index} to cpu.")
|
||||
to_cpu_block_index += 1
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
if not self.single_blocks_to_swap:
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
else:
|
||||
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
|
||||
for block_idx in range(self.single_blocks_to_swap):
|
||||
block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx]
|
||||
if block.parameters().__next__().device.type != "cpu":
|
||||
block.to("cpu") # , non_blocking=True)
|
||||
# print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.")
|
||||
|
||||
block = self.single_blocks[block_idx]
|
||||
if block.parameters().__next__().device.type == "cpu":
|
||||
block.to(self.device)
|
||||
# print(f"Moved single block {block_idx} to cuda.")
|
||||
|
||||
to_cpu_block_index = 0
|
||||
for block_idx, block in enumerate(self.single_blocks):
|
||||
# move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
|
||||
moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap
|
||||
if moving:
|
||||
block.to(self.device) # move to cuda
|
||||
# print(f"Moved single block {block_idx} to cuda.")
|
||||
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
if moving:
|
||||
self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
||||
# print(f"Moved single block {to_cpu_block_index} to cpu.")
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
if self.training and self.cpu_offload_checkpointing:
|
||||
img = img.to(self.device)
|
||||
vec = vec.to(self.device)
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
|
||||
Reference in New Issue
Block a user