feat: add block swap for FLUX.1/SD3 LoRA training

This commit was authored by Kohya S on 2024-11-12 21:39:13 +09:00.
parent 17cf249d76
commit 2cb7a6db02
14 changed files with 288 additions and 629 deletions

View File

@@ -16,13 +16,29 @@ def synchronize_device(device: torch.device):
torch.mps.synchronize()
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# print(module_to_cpu.__class__, module_to_cuda.__class__)
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
else:
if module_to_cuda.weight.data.device.type != device.type:
# print(
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
# )
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
@@ -92,7 +108,7 @@ class Offloader:
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
if self.cuda_available:
swap_weight_devices(block_to_cpu, block_to_cuda)
swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
else:
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
@@ -132,52 +148,6 @@ class Offloader:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
class TrainOffloader(Offloader):
    """
    Offloader that supports backward-pass offloading.

    Blocks are hooked at most once; each hook either submits an async block
    move (via the parent Offloader's `_submit_move_blocks`) or waits for a
    previously submitted move (`_wait_blocks_move`) while backprop runs.
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        super().__init__(num_blocks, blocks_to_swap, device, debug)
        # Block indices that already received a grad hook; guards against
        # registering more than one hook per block.
        self.hook_added = set()

    def create_grad_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
        """
        Create a gradient hook for the block at `block_index`, or return None.

        Returns None when a hook was already created for this block, or when
        this block needs neither a swap submission nor a wait. Otherwise the
        returned callable (suitable as a tensor grad hook) performs exactly
        one of the two actions.
        """
        if block_index in self.hook_added:
            return None
        self.hook_added.add(block_index)

        # -1 for 0-based index, -1 for current block is not fully backpropagated yet
        num_blocks_propagated = self.num_blocks - block_index - 2
        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
        waiting = block_index > 0 and block_index <= self.blocks_to_swap
        if not swapping and not waiting:
            return None

        # create hook
        # Index of the fully-backpropagated block to offload, and the index
        # passed as the to-cuda side of the move (one of the first
        # `blocks_to_swap` blocks); `block_idx_to_wait` is the block whose
        # pending move must complete before it can be used.
        block_idx_to_cpu = self.num_blocks - num_blocks_propagated
        block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
        block_idx_to_wait = block_index - 1

        if self.debug:
            print(
                f"Backward: Created grad hook for block {block_index} with {block_idx_to_cpu}, {block_idx_to_cuda}, {block_idx_to_wait}"
            )
        if swapping:

            def grad_hook(tensor: torch.Tensor):
                self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

            return grad_hook
        else:

            def grad_hook(tensor: torch.Tensor):
                self._wait_blocks_move(block_idx_to_wait)

            return grad_hook
class ModelOffloader(Offloader):
"""
supports forward offloading
@@ -228,6 +198,9 @@ class ModelOffloader(Offloader):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if self.debug:
print("Prepare block devices before forward")
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device

View File

@@ -970,11 +970,16 @@ class Flux(nn.Module):
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1061,10 +1066,11 @@ class Flux(nn.Module):
return img
"""
class FluxUpper(nn.Module):
"""
""
Transformer model for flow matching on sequences.
"""
""
def __init__(self, params: FluxParams):
super().__init__()
@@ -1168,9 +1174,9 @@ class FluxUpper(nn.Module):
class FluxLower(nn.Module):
"""
""
Transformer model for flow matching on sequences.
"""
""
def __init__(self, params: FluxParams):
super().__init__()
@@ -1228,3 +1234,4 @@ class FluxLower(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
"""

View File

@@ -257,14 +257,9 @@ def sample_image_inference(
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log(
{f"sample_{i}": wandb.Image(
image,
caption=prompt # positive prompt as a caption
)},
commit=False
)
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def time_shift(mu: float, sigma: float, t: torch.Tensor):
@@ -324,7 +319,7 @@ def denoise(
)
img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward()
return img
@@ -549,44 +544,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
action="store_true",
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
)
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--text_encoder_batch_size",
type=int,
default=None,
help="text encoder batch size (default: None, use dataset's batch size)"
+ " / text encoderのバッチサイズデフォルト: None, データセットのバッチサイズを使用)",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
)
# copy from Diffusers
parser.add_argument(
"--weighting_scheme",
type=str,
default="none",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--guidance_scale",
type=float,

View File

@@ -18,6 +18,7 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import CLIPTokenizer, T5TokenizerFast
from library import custom_offloading_utils
from library.device_utils import clean_memory_on_device
from .utils import setup_logging
@@ -862,7 +863,8 @@ class MMDiT(nn.Module):
# self.initialize_weights()
self.blocks_to_swap = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
self.offloader = None
self.num_blocks = len(self.joint_blocks)
def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
self.use_scaled_pos_embed = use_scaled_pos_embed
@@ -1055,14 +1057,20 @@ class MMDiT(nn.Module):
# )
return spatial_pos_embed
def enable_block_swap(self, num_blocks: int):
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
n = 1 # async block swap. 1 is enough
self.thread_pool = ThreadPoolExecutor(max_workers=n)
assert (
self.blocks_to_swap <= self.num_blocks - 2
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
)
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_blocks = self.joint_blocks
self.joint_blocks = None
@@ -1073,16 +1081,9 @@ class MMDiT(nn.Module):
self.joint_blocks = save_blocks
def prepare_block_swap_before_forward(self):
# make: first n blocks are on cuda, and last n blocks are on cpu
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# raise ValueError("Block swap is not enabled.")
return
num_blocks = len(self.joint_blocks)
for i in range(num_blocks - self.blocks_to_swap):
self.joint_blocks[i].to(self.device)
for i in range(num_blocks - self.blocks_to_swap, num_blocks):
self.joint_blocks[i].to("cpu")
clean_memory_on_device(self.device)
self.offloader.prepare_block_devices_before_forward(self.joint_blocks)
def forward(
self,
@@ -1122,57 +1123,19 @@ class MMDiT(nn.Module):
if self.register_length > 0:
context = torch.cat(
(
einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
default(context, torch.Tensor([]).type_as(x)),
),
1,
(einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1
)
if not self.blocks_to_swap:
for block in self.joint_blocks:
context, x = block(context, x, c)
else:
futures = {}
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
# print(f"Moving {bidx_to_cpu} to cpu.")
block_to_cpu.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
# print(f"Moving {bidx_to_cuda} to cuda.")
block_to_cuda.to(self.device, non_blocking=True)
torch.cuda.synchronize()
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
return block_idx_to_cpu, block_idx_to_cuda
block_to_cpu = self.joint_blocks[block_idx_to_cpu]
block_to_cuda = self.joint_blocks[block_idx_to_cuda]
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)
def wait_for_blocks_move(block_idx, ftrs):
if block_idx not in ftrs:
return
# print(f"Waiting for move blocks: {block_idx}")
# start_time = time.perf_counter()
ftr = ftrs.pop(block_idx)
ftr.result()
# torch.cuda.synchronize()
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
for block_idx, block in enumerate(self.joint_blocks):
wait_for_blocks_move(block_idx, futures)
self.offloader.wait_for_block(block_idx)
context, x = block(context, x, c)
if block_idx < self.blocks_to_swap:
block_idx_to_cpu = block_idx
block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
self.offloader.submit_move_blocks(self.joint_blocks, block_idx)
x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify
return x[:, :, :H, :W]

View File

@@ -142,27 +142,6 @@ def save_sd3_model_on_epoch_end_or_stepwise(
def add_sd3_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--text_encoder_batch_size",
type=int,
default=None,
help="text encoder batch size (default: None, use dataset's batch size)"
+ " / text encoderのバッチサイズデフォルト: None, データセットのバッチサイズを使用)",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
)
parser.add_argument(
"--clip_l",
type=str,
@@ -253,32 +232,8 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
)
# Dependencies of Diffusers noise sampler has been removed for clarity.
parser.add_argument(
"--weighting_scheme",
type=str,
default="uniform",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム",
)
parser.add_argument(
"--logit_mean",
type=float,
default=0.0,
help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均",
)
parser.add_argument(
"--logit_std",
type=float,
default=1.0,
help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd",
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効",
)
# Dependencies of Diffusers noise sampler has been removed for clarity in training
parser.add_argument(
"--training_shift",
type=float,

View File

@@ -1887,7 +1887,9 @@ class DreamBoothDataset(BaseDataset):
# make image path to npz path mapping
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
npz_paths.sort(key=lambda item: item.rsplit("_", maxsplit=2)[0]) # sort by name excluding resolution and cache_suffix
npz_paths.sort(
key=lambda item: item.rsplit("_", maxsplit=2)[0]
) # sort by name excluding resolution and cache_suffix
npz_path_index = 0
size_set_count = 0
@@ -3537,8 +3539,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--fused_backward_pass",
action="store_true",
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
+ " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL, SD3 and FLUX"
" / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXL、SD3、FLUXでのみ利用可能",
)
parser.add_argument(
"--lr_scheduler_timescale",
@@ -4027,6 +4029,72 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
)
def add_dit_training_arguments(parser: argparse.ArgumentParser):
    """Add training arguments shared by DiT-based models (FLUX.1/SD3) to `parser`."""
    # Text encoder related arguments
    parser.add_argument(
        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
    )
    parser.add_argument(
        "--cache_text_encoder_outputs_to_disk",
        action="store_true",
        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
    )
    parser.add_argument(
        "--text_encoder_batch_size",
        type=int,
        default=None,
        # NOTE(review): the Japanese help text seems to have lost an opening paren
        # before デフォルト (closing ")" is present) — verify against upstream
        help="text encoder batch size (default: None, use dataset's batch size)"
        + " / text encoderのバッチサイズデフォルト: None, データセットのバッチサイズを使用)",
    )

    # Model loading optimization
    parser.add_argument(
        "--disable_mmap_load_safetensors",
        action="store_true",
        help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
    )

    # Training arguments. partial copy from Diffusers
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="uniform",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none", "uniform"],
        help="weighting scheme for timestep distribution. Default is uniform, uniform and none are the same behavior"
        " / タイムステップ分布の重み付けスキーム、デフォルトはuniform、uniform と none は同じ挙動",
    )
    parser.add_argument(
        "--logit_mean",
        type=float,
        default=0.0,
        help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
    )
    parser.add_argument(
        "--logit_std",
        type=float,
        default=1.0,
        help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
    )

    # offloading
    parser.add_argument(
        "--blocks_to_swap",
        type=int,
        default=None,
        help="[EXPERIMENTAL] "
        "Sets the number of blocks to swap during the forward and backward passes."
        "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
        " / 順伝播および逆伝播中にスワップするブロックの数を設定します。"
        "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
    )
def get_sanitized_config_or_none(args: argparse.Namespace):
# if `--log_config` is enabled, return args for logging. if not, return None.
# when `--log_config is enabled, filter out sensitive values from args