This commit is contained in:
araleza
2026-04-02 18:20:27 +02:00
committed by GitHub
6 changed files with 607 additions and 27 deletions

View File

@@ -330,6 +330,8 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
fused_optimizers_supported = ['adafactor', 'adamoffload', 'nadamoffload', 'adamwoffload', 'nadamwoffload', 'adanoffload']
if args.blockwise_fused_optimizers:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
@@ -381,10 +383,25 @@ def train(args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
if (args.optimizer_type in fused_optimizers_supported) and args.full_bf16:
logger.warning("Use of --blockwise_fused_optimizers is preventing stochastic/Kahan weight updates.")
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
# Pass any Kahan summation arg to the optimizer
if args.kahan_summation:
# Self check parameter compatibility
if args.optimizer_type.lower() not in fused_optimizers_supported:
logger.warning("Kahan summation has been requested, but this is not supported by the selected optimizer.")
if not args.full_bf16:
logger.warning("Kahan summation requires --full_bf16")
if args.blockwise_fused_optimizers:
logger.warning("Kahan summation has been requested, but these are not compatible with --blockwise_fused_optimizer. "\
"Perhaps try --fused_backward_pass instead.")
optimizer.use_kahan_summation = args.kahan_summation
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -437,6 +454,28 @@ def train(args):
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
flux.to(weight_dtype)
# Experimental: some layers have very few weights, and training quality seems
# to increase significantly if these are left in f32 format while training.
if args.fused_backward_pass:
from library.flux_models import MixedLinear
from library.flux_models import RMSNorm
flux.final_layer.linear.to(dtype=torch.float32)
flux.img_in .to(dtype=torch.float32)
for m in flux.modules():
num_params = sum(p.numel() for p in m.parameters())
if isinstance(m, MixedLinear) and m.bias is not None:
m.bias.data = m.bias.data.to(torch.float32)
if m.weight.data.numel() < 20000000: # Includes first Linear stage with 18m weights
m.weight.data = m.weight.data.to(torch.float32)
if isinstance(m, RMSNorm):
m.scale.data = m.scale.data.to(torch.float32)
if clip_l is not None:
clip_l.to(weight_dtype)
t5xxl.to(weight_dtype)
@@ -474,10 +513,21 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
# use fused optimizer for backward pass. Only some specific optimizers are supported.
import library.adafactor_fused
import library.adamw_fused
import library.adan_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
if args.optimizer_type.lower() == "adafactor":
library.adafactor_fused.patch_adafactor_fused(optimizer)
elif args.optimizer_type.lower() == "adamoffload" or args.optimizer_type.lower() == "adamwoffload":
library.adamw_fused.patch_adamw_offload_fused(optimizer, False)
elif args.optimizer_type.lower() == "nadamoffload" or args.optimizer_type.lower() == "nadamwoffload":
library.adamw_fused.patch_adamw_offload_fused(optimizer, True) # Nesterov
elif args.optimizer_type.lower() == "adanoffload":
library.adan_fused.patch_adan_offload_fused(optimizer, False) # Adan
else:
logger.error(f"Optimizer '{args.optimizer_type}' does not have a --fused_backward_pass implementation available")
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
@@ -816,6 +866,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--kahan_summation",
action="store_true",
help="Offloads to CPU the float part lost during bf16 quantization, and re-adds it to the next step / "\
"bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",

View File

@@ -28,6 +28,62 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
del result
# Kahan summation for bfloat16
# The implementation was provided by araleza.
# Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192
def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update):
"""
Copies source into target using Kahan summation.
The lower bits of the float32 weight that are lost on conversion to bfloat16
are sent to the CPU until the next step, where they are re-added onto the weights
before adding the gradient update. This produces near float32-like weight behavior,
although the copies back and forth to main memory result in slower training steps.
Args:
target: the target tensor with dtype=bfloat16
source: the target tensor with dtype=float32
state: the optimizer state, used to store kahan residuals
update: the change in weights due to the gradient
"""
# Initialize residuals to 0 for first step
if state.get('kahan_residuals') is None:
state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16)
# Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations
state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32)
# Bring the previous step's lower bits of the weights back from the
# cpu device, and add them back to the weights of the current step.
source_i32 = source.view(dtype=torch.int32) # Can't do math on uint32
source_i32.add_(state['kahan_residuals'])
# Reverse any rounding up during the cast to bf16 on the previous step
rounded_up = state['kahan_residuals'] >= 32768
source_i32[rounded_up] -= 65536
# Must add the gradient update after the bottom bits are restored in case
# the exponent is changed by the update, or the -65536 on the line above
# would drop the uint32 value below zero, which is invalid.
source.add_(-update)
# Get the lower bits into the residual
torch.bitwise_and(source_i32, 0x0000FFFF, out=state['kahan_residuals'])
# Ensure rounding to bfloat16 matches expectations. These lines may not be
# necessary as target.copy_ should do this rounding anyway.
source_i32.add_(32768) # Add offset so clipping bits performs round-to-nearest
source_i32.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32. Leaves only upper bits in source
# Move the 16-bit Kahan bits from VRAM to main memory
state['kahan_residuals'] = state['kahan_residuals'].to(dtype=torch.uint16).to("cpu")
# Copy the quantized floats into the target tensor
target.copy_(source)
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
@@ -102,13 +158,19 @@ def adafactor_step_param(self, p, group):
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
# Add on gradient update, but not if using kahan summation as the bottom
# bits must be restored first. (This update occurs in copy_kahan_() instead)
if not self.optimizer.use_kahan_summation:
p_data_fp32.add_(-update)
# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)
if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
if self.optimizer.use_kahan_summation:
copy_kahan_(p, p_data_fp32, state, update)
else:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)

198
library/adamw_fused.py Normal file
View File

@@ -0,0 +1,198 @@
import math
import torch
from library.adafactor_fused import copy_stochastic_
from library.adafactor_fused import copy_kahan_
def to_float24_bytes(tensor_f32: torch.Tensor) -> torch.Tensor:
    """
    Converts a float32 tensor to a 'float24' representation for storage.

    This is done by taking the 3 most significant bytes of each float32 element
    and dropping the least significant mantissa byte (truncation, no rounding).
    The byte selection assumes a little-endian host, where the most significant
    bytes are the last 3 of each element. x86 and Apple Silicon machines are
    both little-endian; a big-endian host would need bytes [:, :3] instead.

    Args:
        tensor_f32: The input tensor with dtype torch.float32. Any shape and
            memory layout is accepted (including 0-dim, empty and non-contiguous).

    Returns:
        A 1D tensor of dtype torch.uint8 containing the packed 'float24' data,
        3 bytes per input element, on the same device as the input.

    Raises:
        TypeError: if the input is not a float32 tensor.
    """
    if tensor_f32.dtype != torch.float32:
        raise TypeError("Input tensor must be of dtype torch.float32")

    # Reinterpreting as bytes requires a dense 1D layout; contiguous() is a
    # no-op for tensors that are already contiguous.
    flat_f32 = tensor_f32.contiguous().reshape(-1)
    byte_rows = flat_f32.view(torch.uint8).view(flat_f32.numel(), 4)

    # Column 0 is the least significant byte on little-endian hosts: drop it.
    # flatten() on the strided slice produces a compact copy.
    return byte_rows[:, 1:].flatten()
def from_float24_bytes(tensor_f24_u8: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
    """
    Expand packed 'float24' bytes back into a float32 tensor.

    Each group of 3 input bytes becomes bytes 1..3 of a 4-byte element whose
    byte 0 is zero; the 4-byte groups are then reinterpreted as float32. The
    result lives on the same device as the input.

    Args:
        tensor_f24_u8: A 1D tensor of dtype torch.uint8 from to_float24_bytes.
        original_shape: The shape of the original float32 tensor.

    Returns:
        The restored tensor with dtype torch.float32 and the original shape.

    Raises:
        TypeError: if the input is not a uint8 tensor.
        ValueError: if the number of input bytes is not a multiple of 3.
    """
    if tensor_f24_u8.dtype != torch.uint8:
        raise TypeError("Input byte tensor must be of dtype torch.uint8")
    if tensor_f24_u8.numel() % 3 != 0:
        raise ValueError("Input byte tensor size must be a multiple of 3")

    triples = tensor_f24_u8.view(-1, 3)

    # Zero-filled 4-byte rows; column 0 stays zero, columns 1..3 take the data
    quads = torch.zeros(triples.shape[0], 4, dtype=torch.uint8, device=triples.device)
    quads[:, 1:] = triples

    restored_flat = quads.flatten().view(torch.float32)
    return restored_flat.view(original_shape)
@torch.no_grad()
def adamw_offload_step_param(self, p, group):
    """
    Perform one (N)Adam(W) update for a single parameter.

    The moment estimates of large tensors are compressed to 24 bits per element
    and moved to CPU memory between steps; tensors with at most 4096 elements
    keep float32 state on the parameter's device.

    Arguments:
        self: the optimizer this function is bound to. It must provide `state`
            and `use_nesterov`. The Kahan flag `use_kahan_summation` is read
            from `self.optimizer` when that attribute exists, otherwise from
            `self`; it is treated as disabled when absent.
        p: the parameter to update. Nothing happens if `p.grad` is None.
        group: the parameter group of `p`; reads 'betas', 'eps', 'lr' and the
            optional 'weight_decay'.

    Raises:
        RuntimeError: if the gradient is sparse.
    """
    if p.grad is None:
        return

    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    if grad.is_sparse:
        raise RuntimeError("This (N)AdamW implementation does not support sparse gradients.")

    state = self.state[p]

    # Work on a float32 copy for 16-bit parameters; float32 parameters are
    # updated in place.
    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    # Tensors with few elements may be more sensitive to quantization
    # errors, so keep them in float32
    high_quality = torch.numel(p) <= 4096

    # State Initialization
    if len(state) == 0:
        state["step"] = 0

        data_type = torch.float32 if high_quality else torch.uint16

        state['exp_avg'] = torch.zeros_like(p, dtype=data_type)
        state['exp_avg_sq'] = torch.zeros_like(p, dtype=data_type)

    state["step"] += 1

    # NAdam
    beta1, beta2 = group['betas']
    eps = group['eps']  # 1e-8
    weight_decay = group.get('weight_decay', 0.0)

    # Bias correction terms
    bias_correction1 = 1.0 - math.pow(beta1, state['step'])
    bias_correction2 = 1.0 - math.pow(beta2, state['step'])

    # Bring state back (from CPU, if necessary)

    # Recover the exp avg states from however they're stored
    def unpack_tensor(state, key, target_device):

        # Stored as f24 format?
        if state[f'{key}'].dtype == torch.uint8:
            return from_float24_bytes(state[f'{key}'].to(target_device), state[f'{key}_shape'])

        # bf16 / u16 / f32
        return state[f'{key}'].to(target_device).to(dtype=torch.float32)

    state['exp_avg'] = unpack_tensor(state, 'exp_avg', p.device)
    state['exp_avg_sq'] = unpack_tensor(state, 'exp_avg_sq', p.device)
    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

    # Update biased first and second moment estimates
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

    # Compute bias-corrected second moment for denominator
    exp_avg_sq_corrected = exp_avg_sq / bias_correction2

    # Compute update based on whether Nesterov momentum (NAdam) is being used
    if self.use_nesterov:
        # The next step's bias correction for momentum is needed
        bias_correction1_next = 1.0 - math.pow(beta1, state['step'] + 1)

        # NAdam update: combines current gradient with momentum look-ahead
        momentum_cache = exp_avg / bias_correction1_next
        update = (beta1 * momentum_cache + (1.0 - beta1) * grad / bias_correction1) / (exp_avg_sq_corrected.sqrt() + eps)
    else:
        # Standard Adam update: use bias-corrected first moment directly
        exp_avg_corrected = exp_avg / bias_correction1
        update = exp_avg_corrected / (exp_avg_sq_corrected.sqrt() + eps)

    lr: float = group['lr']

    # Implement 'cautious optimizer' from https://arxiv.org/pdf/2411.16085
    # The scaling factor - dividing by m.mean() - does not seem to work with parameter
    # groups, but it also appears to be an optional step, so it has been removed.
    m = (update * grad >= 0).to(grad.dtype)
    update = update * m  # / (m.mean() + eps)

    # Apply learning rate
    update.mul_(lr)

    # Apply weight decay
    if weight_decay != 0:
        p_data_fp32.mul_(1 - lr * weight_decay)

    # Reduce the size of large exp_avg and exp_avg_sq tensors to 24-bit,
    # and then move them to cpu memory
    if not high_quality:
        state['exp_avg_shape'] = state['exp_avg'].shape
        state['exp_avg'] = to_float24_bytes(state['exp_avg']).to('cpu')

        state['exp_avg_sq_shape'] = state['exp_avg_sq'].shape
        state['exp_avg_sq'] = to_float24_bytes(state['exp_avg_sq']).to('cpu')

    # Kahan summation only exists for bfloat16 parameters. For any other dtype
    # the update must be applied here, otherwise the parameter would never
    # change while the Kahan flag is set.
    inner_optimizer = getattr(self, "optimizer", self)
    kahan_requested = getattr(inner_optimizer, "use_kahan_summation", getattr(self, "use_kahan_summation", False))
    use_kahan = bool(kahan_requested) and p.dtype == torch.bfloat16

    # Add on gradient update, but not if using kahan summation as the bottom
    # bits must be restored first. (This update occurs in copy_kahan_() instead)
    if not use_kahan:
        p_data_fp32.add_(-update)

    if p.dtype == torch.bfloat16:
        if use_kahan:
            copy_kahan_(p, p_data_fp32, state, update)
        else:
            copy_stochastic_(p, p_data_fp32)
    elif p.dtype == torch.float16:
        p.copy_(p_data_fp32)
@torch.no_grad()
def adamw_offload_step(self, closure=None):
    """
    Performs a single optimization step

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The value returned by `closure`, or None when no closure is given.
    """
    loss = None
    if closure is not None:
        # This function runs under no_grad; the closure re-evaluates the model
        # and may call backward(), so autograd must be re-enabled for it.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            adamw_offload_step_param(self, p, group)

    return loss
def patch_adamw_offload_fused(optimizer, use_nesterov):
    """
    Replace the optimizer's step logic with the offloading (N)Adam(W) version.

    Sets `use_nesterov` on the optimizer and binds `adamw_offload_step_param`
    and `adamw_offload_step` to it as `step_param` and `step`.

    Arguments:
        optimizer: the optimizer object to patch in place.
        use_nesterov: stored on the optimizer; read by the per-parameter step
            to choose between the Adam and NAdam update.
    """
    optimizer.use_nesterov = use_nesterov

    replacements = (
        ("step_param", adamw_offload_step_param),
        ("step", adamw_offload_step),
    )
    for attribute, function in replacements:
        setattr(optimizer, attribute, function.__get__(optimizer))

218
library/adan_fused.py Normal file
View File

@@ -0,0 +1,218 @@
import math
import torch
from library.adafactor_fused import copy_stochastic_
from library.adafactor_fused import copy_kahan_
# Pack floating point tensors into uint16. Their float32 bytes are interpreted as uint32
# bytes (not cast to uint32). Since positive floats are in sequential order when interpreted
# as uint32s, the groups of positive and negative floats appear as small ranges in uint32
# format. The three clumps (negative floats, zeros, postive floats) then have their min/max
# positions noted, and stretched to cover a uint16 range.
def pack_tensor(state, key, support_neg):
k = state[f'{key}']
k_uint32_f = torch.abs(k).view(torch.uint32).to(torch.float32)
min_val, max_val = torch.aminmax(k_uint32_f[k_uint32_f != 0.0])
# No support_neg (i.e. input floats are only zero or positive). Outputs values in these uint16 ranges:
# 0 <-- 0.0s
# 1..65535 <-- positive floats
# support_neg (i.e. input floats can be zero or +/-). Outputs values in these uint16 ranges:
# 0 <-- 0.0s
# 1..32767 <-- positive floats
# 32768 <-- -0.0 ? Not used.
# 32769..65535 <-- negative floats
range = 32768 if support_neg else 65536
k_int32_scale = (k_uint32_f - min_val) * (range - 2) / (max_val - min_val) + 1 # Scale into [1..range]
packed = torch.where(k > 0, k_int32_scale, 0) # Positive floats and zero
if support_neg:
packed = torch.where(k < 0, k_int32_scale + 32768, packed) # Negative floats
del k_int32_scale
k_uint16_scale = packed.to(torch.uint16)
state[f'{key}'] = k_uint16_scale
state[f'{key}_min'] = min_val
state[f'{key}_max'] = max_val
pass
# Recover adan state tensors packed wtih pack_tensor()
def unpack_tensor(state, key, support_neg):
# uint16 format = packed floats
if state[f'{key}'].dtype == torch.uint16:
packed = state[f'{key}'].to('cuda').to(dtype=torch.float32)
min_val = state[f'{key}_min']
max_val = state[f'{key}_max']
range = 32768.0 if support_neg else 65536.0
if support_neg:
pack_merge_signs = torch.where(packed >= 32768, packed - 32768, packed)
else:
pack_merge_signs = packed
upck = (pack_merge_signs - 1) / (range - 2) * (max_val - min_val) + min_val
upck = torch.where(pack_merge_signs == 0, 0, upck) # 0's are special cased
upck = upck.to(torch.uint32)
upck_final_but_no_negs = upck.view(torch.float32)
if support_neg:
upck_final = torch.where(packed >= 32768, -upck_final_but_no_negs, upck_final_but_no_negs)
else:
upck_final = upck_final_but_no_negs
return upck_final
# bf16 / f32
return state[f'{key}'].to('cuda').to(dtype=torch.float32)
@torch.no_grad()
def adan_offload_step_param(self, p, group):
    """
    Perform one Adan update for a single parameter.

    The optimizer state of large tensors is compressed (uint16 codes via
    pack_tensor(), bfloat16 for the previous negated gradient) and moved to CPU
    memory between steps; tensors with at most 2,000,000 elements keep float32
    state. The betas are currently hard coded and weight decay is not applied.
    The first 3 steps only build up the state and leave the weights unchanged.

    Arguments:
        self: the optimizer this function is bound to. It must provide `state`.
            The Kahan flag `use_kahan_summation` is read from `self.optimizer`
            when that attribute exists, otherwise from `self`; it is treated as
            disabled when absent.
        p: the parameter to update. Nothing happens if `p.grad` is None.
        group: the parameter group of `p`; reads 'eps' and 'lr'.

    Raises:
        RuntimeError: if the gradient is sparse.
    """
    if p.grad is None:
        return

    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    if grad.is_sparse:
        raise RuntimeError("This Adan implementation does not support sparse gradients.")

    state = self.state[p]

    # Work on a float32 copy for 16-bit parameters; float32 parameters are
    # updated in place.
    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    # Tensors with few elements may be more sensitive to quantization
    # errors, so keep them in float32
    high_quality = torch.numel(p) <= 2000000

    # State Initialization
    if len(state) == 0:
        state["step"] = 0

        state_dtype = torch.float32 if high_quality else torch.bfloat16
        state['exp_avg'] = torch.zeros_like(p, dtype=state_dtype)
        state['exp_avg_sq'] = torch.zeros_like(p, dtype=state_dtype)
        state['exp_avg_diff'] = torch.zeros_like(p, dtype=state_dtype)
        state['neg_grad_or_diff'] = torch.zeros_like(p, dtype=state_dtype)

    state["step"] += 1

    #beta1, beta2, beta3 = group['betas']  # Don't have custom class, so beta3 not available
    beta1, beta2, beta3 = (0.98, 0.92, 0.99)  # Hard coded betas for now
    eps = group['eps']  # 1e-8
    # NOTE: group['weight_decay'] is not read; weight decay is not currently implemented

    # Bias correction terms
    bias_correction1 = 1.0 - math.pow(beta1, state['step'])
    bias_correction2 = 1.0 - math.pow(beta2, state['step'])
    bias_correction3 = 1.0 - math.pow(beta3, state['step'])
    bias_correction3_sqrt = math.sqrt(bias_correction3)

    # Recover the exp avg states from however they're stored
    state['exp_avg'] = unpack_tensor(state, 'exp_avg', True)
    state['exp_avg_sq'] = unpack_tensor(state, 'exp_avg_sq', False)
    state['exp_avg_diff'] = unpack_tensor(state, 'exp_avg_diff', True)
    state['neg_grad_or_diff'] = unpack_tensor(state, 'neg_grad_or_diff', True)

    exp_avg = state['exp_avg']
    exp_avg_sq = state['exp_avg_sq']
    exp_avg_diff = state['exp_avg_diff']
    neg_grad_or_diff = state['neg_grad_or_diff']

    # for memory saving, we use `neg_grad_or_diff`
    # to get some temp variable in a inplace way
    neg_grad_or_diff.add_(grad)

    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # m_t
    exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff, alpha=1 - beta2)  # diff_t

    neg_grad_or_diff.mul_(beta2).add_(grad)
    exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff, neg_grad_or_diff, value=1 - beta3)  # n_t

    lr: float = group['lr']

    denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps)
    step_size = lr / bias_correction1
    step_size_diff = lr * beta2 / bias_correction2

    # todo: weight decay not supported
    update = (exp_avg * step_size) / denom
    update += (exp_avg_diff * step_size_diff) / denom

    # Keep the negated gradient for the next step's gradient difference
    neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)

    # Just build momentum for first few steps
    if state['step'] <= 3:
        update.mul_(0.0)

    # Move the optimizer state tensors to main memory
    if not high_quality:

        # float32 to uint16 compression, hopefully provides more precision
        pack_tensor(state, 'exp_avg', True)
        pack_tensor(state, 'exp_avg_sq', False)  # Only positive floats
        pack_tensor(state, 'exp_avg_diff', True)

        state['exp_avg'] = state['exp_avg'].to('cpu')
        state['exp_avg_sq'] = state['exp_avg_sq'].to('cpu')
        state['exp_avg_diff'] = state['exp_avg_diff'].to('cpu')

        # Neg_grad is always a bfloat16 (stored in a float32) already apparently! So
        # can be stored as a bfloat16 exactly.
        state['neg_grad_or_diff'] = state['neg_grad_or_diff'].to(torch.bfloat16).to('cpu')

    # Kahan summation only exists for bfloat16 parameters. For any other dtype
    # the update must be applied here, otherwise the parameter would never
    # change while the Kahan flag is set.
    inner_optimizer = getattr(self, "optimizer", self)
    kahan_requested = getattr(inner_optimizer, "use_kahan_summation", getattr(self, "use_kahan_summation", False))
    use_kahan = bool(kahan_requested) and p.dtype == torch.bfloat16

    # Add on gradient update, but not if using kahan summation as the bottom
    # bits must be restored first. (This update occurs in copy_kahan_() instead)
    if not use_kahan:
        p_data_fp32.add_(-update)

    if p.dtype == torch.bfloat16:
        if use_kahan:
            copy_kahan_(p, p_data_fp32, state, update)
        else:
            copy_stochastic_(p, p_data_fp32)
    elif p.dtype == torch.float16:
        p.copy_(p_data_fp32)
@torch.no_grad()
def adan_offload_step(self, closure=None):
    """
    Performs a single optimization step

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The value returned by `closure`, or None when no closure is given.
    """
    loss = None
    if closure is not None:
        # This function runs under no_grad; the closure re-evaluates the model
        # and may call backward(), so autograd must be re-enabled for it.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            adan_offload_step_param(self, p, group)

    return loss
def patch_adan_offload_fused(optimizer, use_nesterov):
    """
    Replace the optimizer's step logic with the offloading Adan version.

    Sets `use_nesterov` on the optimizer and binds `adan_offload_step_param`
    and `adan_offload_step` to it as `step_param` and `step`.

    Arguments:
        optimizer: the optimizer object to patch in place.
        use_nesterov: stored on the optimizer as `use_nesterov`.
    """
    optimizer.use_nesterov = use_nesterov

    replacements = (
        ("step_param", adan_offload_step_param),
        ("step", adan_offload_step),
    )
    for attribute, function in replacements:
        setattr(optimizer, attribute, function.__get__(optimizer))

View File

@@ -543,12 +543,43 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
return embedding
import torch.nn.functional as F
# A class that supports having the biases have a dtype of float32
# while the more numerous weights are still in bfloat16 format.
class MixedLinear(nn.Module):
    """
    Linear layer with a bfloat16 weight and a float32 bias.

    The parameters are named `weight` and `bias`, and `in_features` /
    `out_features` are exposed, as on nn.Linear.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if False, the layer has no bias (`self.bias` is None).
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        # Same public attributes as nn.Linear, for code that inspects layer sizes
        self.in_features = in_features
        self.out_features = out_features

        # Initialize weights in float32 first, then cast to bfloat16
        weight = torch.empty(out_features, in_features, dtype=torch.float32)
        nn.init.kaiming_uniform_(weight, a=5**0.5)
        self.weight = nn.Parameter(weight.to(torch.bfloat16))

        if bias:
            bias_param = torch.empty(out_features, dtype=torch.float32)  # High precision
            # For a 2D (out, in) weight the fan-in is the number of input features
            bound = 1 / in_features**0.5 if in_features > 0 else 0
            nn.init.uniform_(bias_param, -bound, bound)
            self.bias = nn.Parameter(bias_param)
        else:
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer; a bfloat16 weight is up-cast to float32 first."""
        if self.weight.dtype == torch.bfloat16:
            weight_fp32 = self.weight.to(torch.float32)
        else:
            weight_fp32 = self.weight
        return F.linear(input, weight_fp32, self.bias)

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.in_layer = MixedLinear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
self.out_layer = MixedLinear(hidden_dim, hidden_dim, bias=True)
self.gradient_checkpointing = False
@@ -609,9 +640,9 @@ class SelfAttention(nn.Module):
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.qkv = MixedLinear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
self.proj = MixedLinear(dim, dim)
# this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
@@ -635,7 +666,7 @@ class Modulation(nn.Module):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
self.lin = MixedLinear(dim, self.multiplier * dim, bias=True)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
@@ -659,9 +690,9 @@ class DoubleStreamBlock(nn.Module):
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
MixedLinear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
MixedLinear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
@@ -670,9 +701,9 @@ class DoubleStreamBlock(nn.Module):
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
MixedLinear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
MixedLinear(mlp_hidden_dim, hidden_size, bias=True),
)
self.gradient_checkpointing = False
@@ -780,9 +811,9 @@ class SingleStreamBlock(nn.Module):
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
self.linear1 = MixedLinear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.linear2 = MixedLinear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
@@ -862,8 +893,8 @@ class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
self.linear = MixedLinear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), MixedLinear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
@@ -894,11 +925,11 @@ class Flux(nn.Module):
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.img_in = MixedLinear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.txt_in = MixedLinear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
@@ -1114,11 +1145,11 @@ class ControlNetFlux(nn.Module):
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.img_in = MixedLinear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.txt_in = MixedLinear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
@@ -1151,15 +1182,15 @@ class ControlNetFlux(nn.Module):
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = MixedLinear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_blocks_for_single = nn.ModuleList([])
for _ in range(controlnet_single_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = MixedLinear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks_for_single.append(controlnet_block)
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.pos_embed_input = MixedLinear(self.in_channels, self.hidden_size, bias=True)
self.gradient_checkpointing = False
self.input_hint_block = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1),

View File

@@ -4909,9 +4909,6 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
optimizer_type = optimizer_type.lower()
if args.fused_backward_pass:
assert (
optimizer_type == "Adafactor".lower()
), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
assert (
args.gradient_accumulation_steps == 1
), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
@@ -5155,6 +5152,24 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
optimizer_class = transformers.optimization.Adafactor
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.lower() == "adamoffload" or optimizer_type.lower() == "nadamoffload":
logger.info(f"use [N]AdamOffload optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.Adam
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.lower() == "adamwoffload" or optimizer_type.lower() == "nadamwoffload":
logger.info(f"use [N]AdamWOffload optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW # default weight_decay seems to be 0.01
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.lower() == "adanoffload":
logger.info(f"use AdanOffload optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW # todo: can't set beta3 here yet, need a custom Adan class
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "AdamW".lower():
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW