exp_avg[_sq] are now stored on CPU in a 24-bit format. Also changed some final Flux stages to f32.

araleza
2025-09-28 15:30:02 +01:00
parent 657813346b
commit f6f3d6e34e
2 changed files with 74 additions and 81 deletions


@@ -454,6 +454,13 @@ def train(args):
         ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
         accelerator.print("enable full bf16 training.")
         flux.to(weight_dtype)
+
+        # Experimental: some layers have very few weights, and training quality seems
+        # to increase significantly if these are left in f32 format while training.
+        if args.fused_backward_pass:
+            flux.final_layer.linear.to(dtype=torch.float32)  # Loses lower bits from some saved files,
+            flux.img_in.to(dtype=torch.float32)              # but most saved models aren't f32/f16 anyway.
+
         if clip_l is not None:
             clip_l.to(weight_dtype)
             t5xxl.to(weight_dtype)
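The comment in the added block says these layers hold very few weights. A quick parameter count makes the cost of pinning them to f32 concrete; this is an illustrative sketch (not part of the commit), assuming flux is a loaded Flux model with the same final_layer.linear and img_in attributes used above:

def count_params(module):
    return sum(p.numel() for p in module.parameters())

total = count_params(flux)
kept_f32 = count_params(flux.final_layer.linear) + count_params(flux.img_in)
extra_mib = kept_f32 * 2 / 2**20  # f32 costs 2 bytes/weight more than bf16
print(f"{kept_f32:,} of {total:,} params in f32 "
      f"({100.0 * kept_f32 / total:.3f}%), +{extra_mib:.1f} MiB vs bf16")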


@@ -5,6 +5,53 @@ from library.adafactor_fused import copy_stochastic_
 from library.adafactor_fused import copy_kahan_
+
+
+def to_float24_bytes(tensor_f32: torch.Tensor) -> torch.Tensor:
+    """
+    Converts a float32 tensor to a 'float24' representation for storage.
+
+    This is done by taking the 3 most significant bytes of each float32 element.
+    On a little-endian system, these are the last 3 bytes.
+    # TODO - Confirm this on a genuinely big-endian system. (Macs, both Intel
+    # and Apple Silicon, are little-endian, so they are fine as-is.)
+
+    Args:
+        tensor_f32: The input tensor with dtype torch.float32.
+
+    Returns:
+        A 1D tensor of dtype torch.uint8 containing the packed 'float24' data.
+    """
+    if tensor_f32.dtype != torch.float32:
+        raise TypeError("Input tensor must be of dtype torch.float32")
+
+    tensor_u8 = tensor_f32.view(torch.uint8)
+    tensor_u8_reshaped = tensor_u8.view(-1, 4)
+    tensor_f24_bytes = tensor_u8_reshaped[:, 1:]
+    return tensor_f24_bytes.flatten()
+
+
+def from_float24_bytes(tensor_f24_u8: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
+    """
+    Restores a 'float24' byte tensor back to a float32 tensor.
+
+    Args:
+        tensor_f24_u8: A 1D tensor of dtype torch.uint8 from to_float24_bytes.
+        original_shape: The shape of the original float32 tensor.
+
+    Returns:
+        The restored tensor with dtype torch.float32 and the original shape.
+    """
+    if tensor_f24_u8.dtype != torch.uint8:
+        raise TypeError("Input byte tensor must be of dtype torch.uint8")
+    if tensor_f24_u8.numel() % 3 != 0:
+        raise ValueError("Input byte tensor size must be a multiple of 3")
+
+    tensor_u8_3bytes = tensor_f24_u8.view(-1, 3)
+    padding = torch.zeros(tensor_u8_3bytes.shape[0], 1, dtype=torch.uint8, device=tensor_u8_3bytes.device)
+    tensor_u8_4bytes = torch.cat([padding, tensor_u8_3bytes], dim=1)
+    tensor_f32_flat = tensor_u8_4bytes.flatten().view(torch.float32)
+    return tensor_f32_flat.view(original_shape)
+
+
 @torch.no_grad()
 def adamw_offload_step_param(self, p, group):
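Since the two helpers above just slice off and re-pad the low byte of each element, a round-trip check shows exactly what is lost: the bottom 8 of the 23 mantissa bits. A small self-contained sanity check of the pair (assumes a little-endian machine, as the docstring notes):

import torch

x = torch.randn(1024, dtype=torch.float32)
packed = to_float24_bytes(x)                    # 3 bytes per element instead of 4
restored = from_float24_bytes(packed, x.shape)

assert packed.numel() == 3 * x.numel()
# The sign, all 8 exponent bits and the top 15 mantissa bits survive, so the
# relative truncation error stays below 2**-15 (about 3.1e-5).
rel_err = ((x - restored).abs() / x.abs()).max()
print(f"max relative error: {rel_err:.2e}")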
@@ -31,19 +78,10 @@ def adamw_offload_step_param(self, p, group):
     if len(state) == 0:
         state["step"] = 0
-        if high_quality:
-            # Exponential averages stored in f32 format
-            state['exp_avg'] = torch.zeros_like(p, dtype=torch.float32)
-            state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float32)
-        else:
-            # Exponential averages stored in u16 format
-            state['exp_avg'] = torch.zeros_like(p, dtype=torch.uint16)
-            state['exp_avg_min'] = 0.0
-            state['exp_avg_max'] = 1.0
-            state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.uint16)
-            state['exp_avg_sq_min'] = 0.0
-            state['exp_avg_sq_max'] = 1.0
+        data_type = torch.float32 if high_quality else torch.uint16
+        state['exp_avg'] = torch.zeros_like(p, dtype=data_type)
+        state['exp_avg_sq'] = torch.zeros_like(p, dtype=data_type)

     state["step"] += 1
@@ -59,35 +97,20 @@ def adamw_offload_step_param(self, p, group):
     eps_p2: float = math.pow(eps, 2)

-    # Bring state back from CPU
-    if high_quality:
-        # These exponential averages are already in float32 format
-        state['exp_avg'] = state['exp_avg'].to(p.device)
-        state['exp_avg_sq'] = state['exp_avg_sq'].to(p.device)
-    else:
-        # Unpack these exponential averages from uint16 format
-        #
-        # A power function was applied to the tensor values, as they are usually
-        # distributed in an exponential fashion. After the power function was applied,
-        # the min and max of the results were noted, and then the values were scaled
-        # to the 0-65535 range for storage. This process is reversed here.
-        u16power = 8.0  # This value worked acceptably in testing to spread the values more evenly
-        exp_avg_min = state['exp_avg_min']
-        exp_avg_max = state['exp_avg_max']
-        exp_avg_sq_min = state['exp_avg_sq_min']
-        exp_avg_sq_max = state['exp_avg_sq_max']
-        uint16_recreate_a = state['exp_avg'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_max - exp_avg_min) + exp_avg_min
-        state['exp_avg'] = torch.pow(torch.abs(uint16_recreate_a), u16power) * torch.sgn(uint16_recreate_a)
-        del uint16_recreate_a
-        uint16_recreate_a_sq = state['exp_avg_sq'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_sq_max - exp_avg_sq_min) + exp_avg_sq_min
-        state['exp_avg_sq'] = torch.pow(torch.abs(uint16_recreate_a_sq), u16power) * torch.sgn(uint16_recreate_a_sq)
-        del uint16_recreate_a_sq
+    # Bring state back (from CPU, if necessary)
+    # Recover the exp avg states from however they're stored
+    def unpack_tensor(state, key, target_device):
+        # Stored as f24 format?
+        if state[f'{key}'].dtype == torch.uint8:
+            return from_float24_bytes(state[f'{key}'].to(target_device), state[f'{key}_shape'])
+        # bf16 / u16 / f32
+        return state[f'{key}'].to(target_device).to(dtype=torch.float32)
+
+    state['exp_avg'] = unpack_tensor(state, 'exp_avg', p.device)
+    state['exp_avg_sq'] = unpack_tensor(state, 'exp_avg_sq', p.device)

     exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

     # Update biased first and second moment estimates
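Unpacking on demand like this means only the packed bytes stay resident between steps. The saving is easy to quantify: AdamW keeps two moment tensors per parameter, so packing cuts resident optimizer state from 8 bytes per parameter to 6. Rough arithmetic, taking 12 billion parameters as an illustrative Flux-scale model size (an assumption, not a figure from the commit):

n_params = 12e9              # illustrative model size
f32_bytes = n_params * 2 * 4  # exp_avg + exp_avg_sq at 4 bytes each
f24_bytes = n_params * 2 * 3  # same two tensors packed to 3 bytes each
print(f"f32: {f32_bytes / 2**30:.0f} GiB, "
      f"f24: {f24_bytes / 2**30:.0f} GiB (25% less cpu memory)")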
@@ -119,51 +142,14 @@ def adamw_offload_step_param(self, p, group):
     if weight_decay != 0:
         p_data_fp32.mul_(1 - lr * weight_decay)

-    if high_quality:
-        # These are kept in f32 format between steps
-        state['exp_avg'] = state['exp_avg'].to('cpu')
-        state['exp_avg_sq'] = state['exp_avg_sq'].to('cpu')
-    else:
-        # Compress the exp_avg and exp_avg_sq tensors to cut their size down
-        # from 32 bit to 16 bit.
-        #
-        # A power function is applied to try to linearize the tensor values, as
-        # they are usually distributed in an exponential fashion. It would have
-        # been preferable to use a log() function, but the input values can be
-        # negative, so a pow() function is used instead. The 1/8th power was
-        # chosen fairly arbitrarily, but seemed to distribute the values fairly
-        # reasonably in some simple tests.
-        #
-        # After the power function is applied, the min and max of the resulting
-        # values are stored, and the values are then scaled to the 0-65535 range
-        # for storage.
-        #
-        # Doing this instead of storing these values as bf16 reduced the L1
-        # error between the stored values and the true f32 values by around 90%,
-        # with a notable increase in output image quality.
-        log_exp_avg = torch.pow(torch.abs(state['exp_avg']), 1.0 / u16power) * torch.sgn(state['exp_avg'])
-        exp_avg_min = torch.min(log_exp_avg)
-        exp_avg_max = torch.max(log_exp_avg)
-        state['exp_avg_min'] = exp_avg_min
-        state['exp_avg_max'] = exp_avg_max
-        normalized = (log_exp_avg - exp_avg_min) / (exp_avg_max - exp_avg_min)
-        del log_exp_avg
-        state['exp_avg'] = (normalized * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu')
-        log_exp_avg_sq = torch.pow(torch.abs(state['exp_avg_sq']), 1.0 / u16power) * torch.sgn(state['exp_avg_sq'])
-        exp_avg_sq_min = torch.min(log_exp_avg_sq)
-        exp_avg_sq_max = torch.max(log_exp_avg_sq)
-        state['exp_avg_sq_min'] = exp_avg_sq_min
-        state['exp_avg_sq_max'] = exp_avg_sq_max
-        normalized_sq = (log_exp_avg_sq - exp_avg_sq_min) / (exp_avg_sq_max - exp_avg_sq_min)
-        del log_exp_avg_sq
-        state['exp_avg_sq'] = (normalized_sq * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu')
+    # Reduce the size of large exp_avg and exp_avg_sq tensors to 24-bit,
+    # and then move them to cpu memory
+    if not high_quality:
+        state['exp_avg_shape'] = state['exp_avg'].shape
+        state['exp_avg'] = to_float24_bytes(state['exp_avg']).to('cpu')
+        state['exp_avg_sq_shape'] = state['exp_avg_sq'].shape
+        state['exp_avg_sq'] = to_float24_bytes(state['exp_avg_sq']).to('cpu')
+    state['exp_avg'] = state['exp_avg'].to('cpu')
+    state['exp_avg_sq'] = state['exp_avg_sq'].to('cpu')

     # Add on gradient update, but not if using kahan summation as the bottom
     # bits must be restored first. (This update occurs in copy_kahan_() instead)
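For contrast, the uint16 path that this commit retires for large tensors can be condensed into a pair of helpers. This is a reconstruction from the deleted branch and its comments above (u16power = 8.0), not code from the commit, and it needs a PyTorch version with torch.uint16 support, just as the diff itself does:

import torch

def pack_u16(t, power=8.0):
    # Signed 1/8th power to roughly linearize the exponentially distributed
    # values, then min/max scaling into the 0-65535 range.
    lin = torch.pow(torch.abs(t), 1.0 / power) * torch.sgn(t)
    lo, hi = torch.min(lin), torch.max(lin)
    q = ((lin - lo) / (hi - lo) * 65535.0).clamp(0, 65535).to(torch.uint16)
    return q, lo, hi

def unpack_u16(q, lo, hi, power=8.0):
    # Reverse the scaling, then undo the power function.
    lin = q.to(torch.float32) / 65535.0 * (hi - lo) + lo
    return torch.pow(torch.abs(lin), power) * torch.sgn(lin)

moments = torch.randn(4096) * 1e-4          # moment-like magnitudes
q, lo, hi = pack_u16(moments)
print((moments - unpack_u16(q, lo, hi)).abs().mean())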