From f6f3d6e34e795efb66253ab86a297def040ae612 Mon Sep 17 00:00:00 2001 From: araleza <70412719+araleza@users.noreply.github.com> Date: Sun, 28 Sep 2025 15:30:02 +0100 Subject: [PATCH] Now exp_avg[_sq] are stored on cpu in 24 bit format. Also changed some final Flux stages to f32. --- flux_train.py | 7 ++ library/adamw_fused.py | 148 +++++++++++++++++++---------------------- 2 files changed, 74 insertions(+), 81 deletions(-) diff --git a/flux_train.py b/flux_train.py index 5228ec13..14fb5cc8 100644 --- a/flux_train.py +++ b/flux_train.py @@ -454,6 +454,13 @@ def train(args): ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" accelerator.print("enable full bf16 training.") flux.to(weight_dtype) + + # Experimental: some layers have very few weights, and training quality seems + # to increase significantly if these are left in f32 format while training. + if args.fused_backward_pass: + flux.final_layer.linear.to(dtype=torch.float32) # Loses lower bits from some saved files, + flux.img_in .to(dtype=torch.float32) # but most saved models aren't f32/f16 anyway. + if clip_l is not None: clip_l.to(weight_dtype) t5xxl.to(weight_dtype) diff --git a/library/adamw_fused.py b/library/adamw_fused.py index b439c5d9..8419d422 100644 --- a/library/adamw_fused.py +++ b/library/adamw_fused.py @@ -5,6 +5,53 @@ from library.adafactor_fused import copy_stochastic_ from library.adafactor_fused import copy_kahan_ +def to_float24_bytes(tensor_f32: torch.Tensor) -> torch.Tensor: + """ + Converts a float32 tensor to a 'float24' representation for storage. + + This is done by taking the 3 most significant bytes of each float32 element. + On a little-endian system, these are the last 3 bytes. + # TODO - Check this on big-endian hosts (Intel and Apple Silicon Macs are little-endian) + + Args: + tensor_f32: The input tensor with dtype torch.float32. + + Returns: + A 1D tensor of dtype torch.uint8 containing the packed 'float24' data. 
+ """ + if tensor_f32.dtype != torch.float32: + raise TypeError("Input tensor must be of dtype torch.float32") + + tensor_u8 = tensor_f32.view(torch.uint8) + tensor_u8_reshaped = tensor_u8.view(-1, 4) + tensor_f24_bytes = tensor_u8_reshaped[:, 1:] + return tensor_f24_bytes.flatten() + + +def from_float24_bytes(tensor_f24_u8: torch.Tensor, original_shape: torch.Size) -> torch.Tensor: + """ + Restores a 'float24' byte tensor back to a float32 tensor. + + Args: + tensor_f24_u8: A 1D tensor of dtype torch.uint8 from to_float24_bytes. + original_shape: The shape of the original float32 tensor. + (There is no device argument: the result stays on tensor_f24_u8's device.) + + Returns: + The restored tensor with dtype torch.float32 and the original shape. + """ + if tensor_f24_u8.dtype != torch.uint8: + raise TypeError("Input byte tensor must be of dtype torch.uint8") + if tensor_f24_u8.numel() % 3 != 0: + raise ValueError("Input byte tensor size must be a multiple of 3") + + tensor_u8_3bytes = tensor_f24_u8.view(-1, 3) + padding = torch.zeros(tensor_u8_3bytes.shape[0], 1, dtype=torch.uint8, device=tensor_u8_3bytes.device) + tensor_u8_4bytes = torch.cat([padding, tensor_u8_3bytes], dim=1) + tensor_f32_flat = tensor_u8_4bytes.flatten().view(torch.float32) + return tensor_f32_flat.view(original_shape) + + @torch.no_grad() def adamw_offload_step_param(self, p, group): @@ -31,19 +78,10 @@ def adamw_offload_step_param(self, p, group): if len(state) == 0: state["step"] = 0 - if high_quality: - # Exponential averages stored in f32 format - state['exp_avg'] = torch.zeros_like(p, dtype=torch.float32) - state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float32) - else: - # Exponential averages stored in u16 format - state['exp_avg'] = torch.zeros_like(p, dtype=torch.uint16) - state['exp_avg_min'] = 0.0 - state['exp_avg_max'] = 1.0 + data_type = torch.float32 if high_quality else torch.uint16 - state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.uint16) - state['exp_avg_sq_min'] = 0.0 - 
state['exp_avg_sq_max'] = 1.0 + state['exp_avg'] = torch.zeros_like(p, dtype=data_type) + state['exp_avg_sq'] = torch.zeros_like(p, dtype=data_type) state["step"] += 1 @@ -59,35 +97,20 @@ def adamw_offload_step_param(self, p, group): eps_p2: float = math.pow(eps, 2) - # Bring state back from CPU + # Bring state back (from CPU, if necessary) - if high_quality: - # These exponential averages are already in float32 format - state['exp_avg'] = state['exp_avg'] .to(p.device) - state['exp_avg_sq'] = state['exp_avg_sq'].to(p.device) - else: - # Unpack these exponential averages from uint16 format + # Recover the exp avg states from however they're stored + def unpack_tensor(state, key, target_device): - # A power function was applied to the tensor values, as they are usually - # distributed in an exponential fashion. After the power function was applied, - # the min and max of the results were noted, and then the values were scaled - # to the 0-65535 range for storage. This process is reversed here. + # Stored as f24 format? 
+ if state[f'{key}'].dtype == torch.uint8: + return from_float24_bytes(state[f'{key}'].to(target_device), state[f'{key}_shape']) - u16power = 8.0 # This value worked acceptably in testing to spread the values more evenly - - exp_avg_min = state['exp_avg_min'] - exp_avg_max = state['exp_avg_max'] - exp_avg_sq_min = state['exp_avg_sq_min'] - exp_avg_sq_max = state['exp_avg_sq_max'] - - uint16_recreate_a = state['exp_avg'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_max - exp_avg_min) + exp_avg_min - state['exp_avg'] = torch.pow(torch.abs(uint16_recreate_a), u16power) * torch.sgn(uint16_recreate_a) - del uint16_recreate_a - - uint16_recreate_a_sq = state['exp_avg_sq'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_sq_max - exp_avg_sq_min) + exp_avg_sq_min - state['exp_avg_sq'] = torch.pow(torch.abs(uint16_recreate_a_sq), u16power) * torch.sgn(uint16_recreate_a_sq) - del uint16_recreate_a_sq + # bf16 / u16 / f32 + return state[f'{key}'].to(target_device).to(dtype=torch.float32) + state['exp_avg'] = unpack_tensor(state, 'exp_avg', p.device) + state['exp_avg_sq'] = unpack_tensor(state, 'exp_avg_sq', p.device) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] # Update biased first and second moment estimates @@ -119,51 +142,14 @@ def adamw_offload_step_param(self, p, group): if weight_decay != 0: p_data_fp32.mul_(1 - lr * weight_decay) - if high_quality: + # Reduce the size of large exp_avg and exp_avg_sq tensors to 24-bit, + # and then move them to cpu memory + if not high_quality: + state[f'exp_avg_shape'] = state[f'exp_avg'].shape + state[f'exp_avg'] = to_float24_bytes(state[f'exp_avg']).to('cpu') - # These are kept in f32 format between steps - state['exp_avg'] = state['exp_avg'].to('cpu') - state['exp_avg_sq'] = state['exp_avg_sq'].to('cpu') - - else: - - # Compress the exp_avg and exp_avg_sq tensors to cut their size down - # from 32 bit to 16 bit. 
- # - # A power function is applied to try to linearize the tensor values, as - # they are usually distributed in an exponential fashion. It would have - # been preferable to use a log() function, but the input values can be - # negative, so a pow() function is used instead. The 1/16th power was - # chosen fairly arbitrarily, but seemed to distribute the values fairly - # reasonably in some simple tests. - # - # After the power function is applied, the min and max of the resulting - # values are stored, and the values are then scaled to the 0-65535 range - # for storage. - # - # Doing this instead of storing these values as bf16 reduced the L1 - # error between the stored values and the true f32 values by around 90%, - # with a notable increase in output image quality. - - log_exp_avg = torch.pow(torch.abs(state['exp_avg']), 1.0 / u16power) * torch.sgn(state['exp_avg']) - exp_avg_min = torch.min(log_exp_avg) - exp_avg_max = torch.max(log_exp_avg) - state['exp_avg_min'] = exp_avg_min - state['exp_avg_max'] = exp_avg_max - normalized = (log_exp_avg - exp_avg_min) / (exp_avg_max - exp_avg_min) - del log_exp_avg - - state['exp_avg'] = (normalized * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu') - - log_exp_avg_sq = torch.pow(torch.abs(state['exp_avg_sq']), 1.0 / u16power) * torch.sgn(state['exp_avg_sq']) - exp_avg_sq_min = torch.min(log_exp_avg_sq) - exp_avg_sq_max = torch.max(log_exp_avg_sq) - state['exp_avg_sq_min'] = exp_avg_sq_min - state['exp_avg_sq_max'] = exp_avg_sq_max - normalized_sq = (log_exp_avg_sq - exp_avg_sq_min) / (exp_avg_sq_max - exp_avg_sq_min) - del log_exp_avg_sq - - state['exp_avg_sq'] = (normalized_sq * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu') + state[f'exp_avg_sq_shape'] = state[f'exp_avg_sq'].shape + state[f'exp_avg_sq'] = to_float24_bytes(state[f'exp_avg_sq']).to('cpu') # Add on gradient update, but not if using kahan summation as the bottom # bits must be restored first. (This update occurs in copy_kahan_() instead)