Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
new block swap for FLUX.1 fine tuning
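This change keeps only part of the transformer resident on the GPU: the first num_block_units - blocks_to_swap block units stay on CUDA, the last blocks_to_swap wait on CPU, and a worker thread rotates each freshly used unit out while prefetching the unit that will be needed next. A minimal standalone sketch of that rolling schedule (illustrative names, not the commit's code):

    # Minimal sketch of the rolling swap schedule (illustrative, not the commit's code).
    from concurrent.futures import ThreadPoolExecutor

    num_units, swapped = 38, 8      # e.g. FLUX.1: 19 double blocks + 38 // 2 single-block pairs
    pool = ThreadPoolExecutor(max_workers=1)
    futures = {}

    def move(unit_to_cpu, unit_to_cuda):
        # stand-in for block.to("cpu") / block.to(device) on the worker thread
        return unit_to_cpu, unit_to_cuda

    for i in range(num_units):
        if i in futures:
            futures.pop(i).result()  # wait until unit i has landed on the GPU
        # ... run unit i's forward here ...
        if i < swapped:              # unit i is done for this step: evict it and
            tail = num_units - swapped + i  # prefetch the tail unit, keyed by its index
            futures[tail] = pool.submit(move, i, tail)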
@@ -2,9 +2,12 @@
 # license: Apache-2.0 License
 
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 import math
-from typing import Optional
+import os
+import time
+from typing import Dict, List, Optional
 
 from library.device_utils import init_ipex, clean_memory_on_device
@@ -917,8 +920,10 @@ class Flux(nn.Module):
 
         self.gradient_checkpointing = False
         self.cpu_offload_checkpointing = False
-        self.double_blocks_to_swap = None
-        self.single_blocks_to_swap = None
+        self.blocks_to_swap = None
+
+        self.thread_pool: Optional[ThreadPoolExecutor] = None
+        self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
 
     @property
     def device(self):
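A swap unit here is either one double block or a pair of consecutive single blocks, so for FLUX.1, which has 19 double and 38 single blocks, the unit count works out as:

    # Worked example of the unit count (FLUX.1: 19 double, 38 single blocks).
    len_double, len_single = 19, 38
    num_block_units = len_double + len_single // 2  # 19 + 19 = 38 swappable units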
@@ -956,38 +961,52 @@ class Flux(nn.Module):
 
         print("FLUX: Gradient checkpointing disabled.")
 
-    def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]):
-        self.double_blocks_to_swap = double_blocks
-        self.single_blocks_to_swap = single_blocks
+    def enable_block_swap(self, num_blocks: int):
+        self.blocks_to_swap = num_blocks
+
+        n = 1  # async block swap. 1 is enough
+        # n = 2
+        # n = max(1, os.cpu_count() // 2)
+        self.thread_pool = ThreadPoolExecutor(max_workers=n)
 
     def move_to_device_except_swap_blocks(self, device: torch.device):
         # assume model is on cpu
-        if self.double_blocks_to_swap:
+        if self.blocks_to_swap:
             save_double_blocks = self.double_blocks
-            self.double_blocks = None
-        if self.single_blocks_to_swap:
             save_single_blocks = self.single_blocks
+            self.double_blocks = None
             self.single_blocks = None
 
         self.to(device)
 
-        if self.double_blocks_to_swap:
+        if self.blocks_to_swap:
             self.double_blocks = save_double_blocks
-        if self.single_blocks_to_swap:
             self.single_blocks = save_single_blocks
 
+    def get_block_unit(self, index: int):
+        if index < len(self.double_blocks):
+            return (self.double_blocks[index],)
+        else:
+            index -= len(self.double_blocks)
+            index *= 2
+            return self.single_blocks[index], self.single_blocks[index + 1]
+
+    def get_unit_index(self, is_double: bool, index: int):
+        if is_double:
+            return index
+        else:
+            return len(self.double_blocks) + index // 2
+
     def prepare_block_swap_before_forward(self):
-        # move last n blocks to cpu: they are on cuda
-        if self.double_blocks_to_swap:
-            for i in range(len(self.double_blocks) - self.double_blocks_to_swap):
-                self.double_blocks[i].to(self.device)
-            for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)):
-                self.double_blocks[i].to("cpu")  # , non_blocking=True)
-        if self.single_blocks_to_swap:
-            for i in range(len(self.single_blocks) - self.single_blocks_to_swap):
-                self.single_blocks[i].to(self.device)
-            for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)):
-                self.single_blocks[i].to("cpu")  # , non_blocking=True)
+        # make: first n blocks are on cuda, and last n blocks are on cpu
+        if self.blocks_to_swap is None:
+            raise ValueError("Block swap is not enabled.")
+        for i in range(self.num_block_units - self.blocks_to_swap):
+            for b in self.get_block_unit(i):
+                b.to(self.device)
+        for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
+            for b in self.get_block_unit(i):
+                b.to("cpu")
         clean_memory_on_device(self.device)
 
     def forward(
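get_block_unit and get_unit_index are inverse views of the same numbering: units 0-18 are the double blocks, and units 19-37 each cover two consecutive single blocks. A standalone check of that mapping (the 19/38 block counts are assumed from FLUX.1):

    # Standalone check of the unit numbering (19 double / 38 single assumed).
    LEN_DOUBLE, LEN_SINGLE = 19, 38

    def get_unit_index(is_double: bool, index: int) -> int:
        return index if is_double else LEN_DOUBLE + index // 2

    def unit_members(unit: int):  # mirrors get_block_unit, but returns indices
        if unit < LEN_DOUBLE:
            return (("double", unit),)
        base = (unit - LEN_DOUBLE) * 2
        return (("single", base), ("single", base + 1))

    assert get_unit_index(True, 5) == 5                        # double block 5 -> unit 5
    assert get_unit_index(False, 0) == get_unit_index(False, 1) == 19
    assert unit_members(19) == (("single", 0), ("single", 1))  # unit 19 -> single blocks 0 and 1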
@@ -1017,69 +1036,73 @@ class Flux(nn.Module):
         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)
 
-        if not self.double_blocks_to_swap:
+        if not self.blocks_to_swap:
             for block in self.double_blocks:
                 img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
-        else:
-            # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
-            for block_idx in range(self.double_blocks_to_swap):
-                block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx]
-                if block.parameters().__next__().device.type != "cpu":
-                    block.to("cpu")  # , non_blocking=True)
-                    # print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.")
-
-                block = self.double_blocks[block_idx]
-                if block.parameters().__next__().device.type == "cpu":
-                    block.to(self.device)
-                    # print(f"Moved double block {block_idx} to cuda.")
-
-            to_cpu_block_index = 0
-            for block_idx, block in enumerate(self.double_blocks):
-                # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
-                moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap
-                if moving:
-                    block.to(self.device)  # move to cuda
-                    # print(f"Moved double block {block_idx} to cuda.")
-
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
-
-                if moving:
-                    self.double_blocks[to_cpu_block_index].to("cpu")  # , non_blocking=True)
-                    # print(f"Moved double block {to_cpu_block_index} to cpu.")
-                    to_cpu_block_index += 1
-
-        img = torch.cat((txt, img), 1)
-
-        if not self.single_blocks_to_swap:
+            img = torch.cat((txt, img), 1)
             for block in self.single_blocks:
                 img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
         else:
-            # make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
-            for block_idx in range(self.single_blocks_to_swap):
-                block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx]
-                if block.parameters().__next__().device.type != "cpu":
-                    block.to("cpu")  # , non_blocking=True)
-                    # print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.")
-
-                block = self.single_blocks[block_idx]
-                if block.parameters().__next__().device.type == "cpu":
-                    block.to(self.device)
-                    # print(f"Moved single block {block_idx} to cuda.")
-
-            to_cpu_block_index = 0
-            for block_idx, block in enumerate(self.single_blocks):
-                # move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
-                moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap
-                if moving:
-                    block.to(self.device)  # move to cuda
-                    # print(f"Moved single block {block_idx} to cuda.")
-
-                img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
-
-                if moving:
-                    self.single_blocks[to_cpu_block_index].to("cpu")  # , non_blocking=True)
-                    # print(f"Moved single block {to_cpu_block_index} to cpu.")
-                    to_cpu_block_index += 1
+            futures = {}
+
+            def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
+                def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
+                    # print(f"Moving {bidx_to_cpu} to cpu.")
+                    for block in blocks_to_cpu:
+                        block.to("cpu", non_blocking=True)
+                    torch.cuda.empty_cache()
+
+                    # print(f"Moving {bidx_to_cuda} to cuda.")
+                    for block in blocks_to_cuda:
+                        block.to(self.device, non_blocking=True)
+
+                    torch.cuda.synchronize()
+                    # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
+                    return block_idx_to_cpu, block_idx_to_cuda
+
+                blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
+                blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
+                # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
+                return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
+
+            def wait_for_blocks_move(block_idx, ftrs):
+                if block_idx not in ftrs:
+                    return
+                # print(f"Waiting for move blocks: {block_idx}")
+                # start_time = time.perf_counter()
+                ftr = ftrs.pop(block_idx)
+                ftr.result()
+                # torch.cuda.synchronize()
+                # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
+
+            for block_idx, block in enumerate(self.double_blocks):
+                # print(f"Double block {block_idx}")
+                unit_idx = self.get_unit_index(is_double=True, index=block_idx)
+                wait_for_blocks_move(unit_idx, futures)
+
+                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
+
+                if unit_idx < self.blocks_to_swap:
+                    block_idx_to_cpu = unit_idx
+                    block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
+                    future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
+                    futures[block_idx_to_cuda] = future
+
+            img = torch.cat((txt, img), 1)
+
+            for block_idx, block in enumerate(self.single_blocks):
+                # print(f"Single block {block_idx}")
+                unit_idx = self.get_unit_index(is_double=False, index=block_idx)
+                if block_idx % 2 == 0:
+                    wait_for_blocks_move(unit_idx, futures)
+
+                img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
+
+                if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
+                    block_idx_to_cpu = unit_idx
+                    block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
+                    future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
+                    futures[block_idx_to_cuda] = future
 
         img = img[:, txt.shape[1] :, ...]
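Because move_blocks runs on the thread-pool worker, the device copies overlap with compute on the main thread, and the trailing torch.cuda.synchronize() means a completed future guarantees the weights are actually in place. One caveat: .to("cpu", non_blocking=True) only truly overlaps when the destination is pinned memory; with pageable memory the copy still serializes. A distilled version of the mover pattern, under those assumptions (a sketch, not the commit's exact code):

    # Distilled pool-based mover (a sketch of the pattern, not the commit's exact code).
    import torch
    from concurrent.futures import ThreadPoolExecutor

    pool = ThreadPoolExecutor(max_workers=1)

    def submit_swap(mods_to_cpu, mods_to_cuda, device):
        def worker():
            for m in mods_to_cpu:
                m.to("cpu", non_blocking=True)  # overlaps fully only with pinned memory
            torch.cuda.empty_cache()            # let the freed device blocks be reused
            for m in mods_to_cuda:
                m.to(device, non_blocking=True)
            torch.cuda.synchronize()            # future done == copies have landed
        return pool.submit(worker)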
@@ -1088,6 +1111,7 @@ class Flux(nn.Module):
+            vec = vec.to(self.device)
 
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
 
         return img
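For context, a training script would drive this API in roughly the following order (model, device, and the loop body are hypothetical; the trainer wiring is outside this diff):

    # Hypothetical call order for the block-swap API added in this commit.
    model.enable_block_swap(8)                       # keep 8 of the 38 units off-GPU
    model.move_to_device_except_swap_blocks(device)  # everything else -> CUDA
    model.prepare_block_swap_before_forward()        # settle units before the first step

    for batch in dataloader:
        loss = training_step(model, batch)           # forward() rotates units itself
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        model.prepare_block_swap_before_forward()    # re-settle units for the next step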