mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
Now exp_avg[_sq] are stored on cpu in 24 bit format. Also changed some final Flux stages to f32.
This commit is contained in:
@@ -454,6 +454,13 @@ def train(args):
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
flux.to(weight_dtype)
|
||||
|
||||
# Experimental: some layers have very few weights, and training quality seems
|
||||
# to increase significantly if these are left in f32 format while training.
|
||||
if args.fused_backward_pass:
|
||||
flux.final_layer.linear.to(dtype=torch.float32) # Loses lower bits from some saved files,
|
||||
flux.img_in .to(dtype=torch.float32) # but most saved models aren't f32/f16 anyway.
|
||||
|
||||
if clip_l is not None:
|
||||
clip_l.to(weight_dtype)
|
||||
t5xxl.to(weight_dtype)
|
||||
|
||||
@@ -5,6 +5,53 @@ from library.adafactor_fused import copy_stochastic_
|
||||
from library.adafactor_fused import copy_kahan_
|
||||
|
||||
|
||||
def to_float24_bytes(tensor_f32: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts a float32 tensor to a 'float24' representation for storage.
|
||||
|
||||
This is done by taking the 3 most significant bytes of each float32 element.
|
||||
On a little-endian system, these are the last 3 bytes.
|
||||
# TODO - Check this works on Mac, which is a big-endian system
|
||||
|
||||
Args:
|
||||
tensor_f32: The input tensor with dtype torch.float32.
|
||||
|
||||
Returns:
|
||||
A 1D tensor of dtype torch.uint8 containing the packed 'float24' data.
|
||||
"""
|
||||
if tensor_f32.dtype != torch.float32:
|
||||
raise TypeError("Input tensor must be of dtype torch.float32")
|
||||
|
||||
tensor_u8 = tensor_f32.view(torch.uint8)
|
||||
tensor_u8_reshaped = tensor_u8.view(-1, 4)
|
||||
tensor_f24_bytes = tensor_u8_reshaped[:, 1:]
|
||||
return tensor_f24_bytes.flatten()
|
||||
|
||||
|
||||
def from_float24_bytes(tensor_f24_u8: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
|
||||
"""
|
||||
Restores a 'float24' byte tensor back to a float32 tensor.
|
||||
|
||||
Args:
|
||||
tensor_f24_u8: A 1D tensor of dtype torch.uint8 from to_float24_bytes.
|
||||
original_shape: The shape of the original float32 tensor.
|
||||
device: The device to create the restored tensor on.
|
||||
|
||||
Returns:
|
||||
The restored tensor with dtype torch.float32 and the original shape.
|
||||
"""
|
||||
if tensor_f24_u8.dtype != torch.uint8:
|
||||
raise TypeError("Input byte tensor must be of dtype torch.uint8")
|
||||
if tensor_f24_u8.numel() % 3 != 0:
|
||||
raise ValueError("Input byte tensor size must be a multiple of 3")
|
||||
|
||||
tensor_u8_3bytes = tensor_f24_u8.view(-1, 3)
|
||||
padding = torch.zeros(tensor_u8_3bytes.shape[0], 1, dtype=torch.uint8, device=tensor_u8_3bytes.device)
|
||||
tensor_u8_4bytes = torch.cat([padding, tensor_u8_3bytes], dim=1)
|
||||
tensor_f32_flat = tensor_u8_4bytes.flatten().view(torch.float32)
|
||||
return tensor_f32_flat.view(original_shape)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def adamw_offload_step_param(self, p, group):
|
||||
|
||||
@@ -31,19 +78,10 @@ def adamw_offload_step_param(self, p, group):
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
|
||||
if high_quality:
|
||||
# Exponential averages stored in f32 format
|
||||
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float32)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float32)
|
||||
else:
|
||||
# Exponential averages stored in u16 format
|
||||
state['exp_avg'] = torch.zeros_like(p, dtype=torch.uint16)
|
||||
state['exp_avg_min'] = 0.0
|
||||
state['exp_avg_max'] = 1.0
|
||||
data_type = torch.float32 if high_quality else torch.uint16
|
||||
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.uint16)
|
||||
state['exp_avg_sq_min'] = 0.0
|
||||
state['exp_avg_sq_max'] = 1.0
|
||||
state['exp_avg'] = torch.zeros_like(p, dtype=data_type)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, dtype=data_type)
|
||||
|
||||
state["step"] += 1
|
||||
|
||||
@@ -59,35 +97,20 @@ def adamw_offload_step_param(self, p, group):
|
||||
|
||||
eps_p2: float = math.pow(eps, 2)
|
||||
|
||||
# Bring state back from CPU
|
||||
# Bring state back (from CPU, if necessary)
|
||||
|
||||
if high_quality:
|
||||
# These exponential averages are already in float32 format
|
||||
state['exp_avg'] = state['exp_avg'] .to(p.device)
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].to(p.device)
|
||||
else:
|
||||
# Unpack these exponential averages from uint16 format
|
||||
# Recover the exp avg states from however they're stored
|
||||
def unpack_tensor(state, key, target_device):
|
||||
|
||||
# A power function was applied to the tensor values, as they are usually
|
||||
# distributed in an exponential fashion. After the power function was applied,
|
||||
# the min and max of the results were noted, and then the values were scaled
|
||||
# to the 0-65535 range for storage. This process is reversed here.
|
||||
# Stored as f24 format?
|
||||
if state[f'{key}'].dtype == torch.uint8:
|
||||
return from_float24_bytes(state[f'{key}'].to(target_device), state[f'{key}_shape'])
|
||||
|
||||
u16power = 8.0 # This value worked acceptably in testing to spread the values more evenly
|
||||
|
||||
exp_avg_min = state['exp_avg_min']
|
||||
exp_avg_max = state['exp_avg_max']
|
||||
exp_avg_sq_min = state['exp_avg_sq_min']
|
||||
exp_avg_sq_max = state['exp_avg_sq_max']
|
||||
|
||||
uint16_recreate_a = state['exp_avg'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_max - exp_avg_min) + exp_avg_min
|
||||
state['exp_avg'] = torch.pow(torch.abs(uint16_recreate_a), u16power) * torch.sgn(uint16_recreate_a)
|
||||
del uint16_recreate_a
|
||||
|
||||
uint16_recreate_a_sq = state['exp_avg_sq'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_sq_max - exp_avg_sq_min) + exp_avg_sq_min
|
||||
state['exp_avg_sq'] = torch.pow(torch.abs(uint16_recreate_a_sq), u16power) * torch.sgn(uint16_recreate_a_sq)
|
||||
del uint16_recreate_a_sq
|
||||
# bf16 / u16 / f32
|
||||
return state[f'{key}'].to(target_device).to(dtype=torch.float32)
|
||||
|
||||
state['exp_avg'] = unpack_tensor(state, 'exp_avg', p.device)
|
||||
state['exp_avg_sq'] = unpack_tensor(state, 'exp_avg_sq', p.device)
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
|
||||
# Update biased first and second moment estimates
|
||||
@@ -119,51 +142,14 @@ def adamw_offload_step_param(self, p, group):
|
||||
if weight_decay != 0:
|
||||
p_data_fp32.mul_(1 - lr * weight_decay)
|
||||
|
||||
if high_quality:
|
||||
# Reduce the size of large exp_avg and exp_avg_sq tensors to 24-bit,
|
||||
# and then move them to cpu memory
|
||||
if not high_quality:
|
||||
state[f'exp_avg_shape'] = state[f'exp_avg'].shape
|
||||
state[f'exp_avg'] = to_float24_bytes(state[f'exp_avg']).to('cpu')
|
||||
|
||||
# These are kept in f32 format between steps
|
||||
state['exp_avg'] = state['exp_avg'].to('cpu')
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].to('cpu')
|
||||
|
||||
else:
|
||||
|
||||
# Compress the exp_avg and exp_avg_sq tensors to cut their size down
|
||||
# from 32 bit to 16 bit.
|
||||
#
|
||||
# A power function is applied to try to linearize the tensor values, as
|
||||
# they are usually distributed in an exponential fashion. It would have
|
||||
# been preferable to use a log() function, but the input values can be
|
||||
# negative, so a pow() function is used instead. The 1/16th power was
|
||||
# chosen fairly arbitrarily, but seemed to distribute the values fairly
|
||||
# reasonably in some simple tests.
|
||||
#
|
||||
# After the power function is applied, the min and max of the resulting
|
||||
# values are stored, and the values are then scaled to the 0-65535 range
|
||||
# for storage.
|
||||
#
|
||||
# Doing this instead of storing these values as bf16 reduced the L1
|
||||
# error between the stored values and the true f32 values by around 90%,
|
||||
# with a notable increase in output image quality.
|
||||
|
||||
log_exp_avg = torch.pow(torch.abs(state['exp_avg']), 1.0 / u16power) * torch.sgn(state['exp_avg'])
|
||||
exp_avg_min = torch.min(log_exp_avg)
|
||||
exp_avg_max = torch.max(log_exp_avg)
|
||||
state['exp_avg_min'] = exp_avg_min
|
||||
state['exp_avg_max'] = exp_avg_max
|
||||
normalized = (log_exp_avg - exp_avg_min) / (exp_avg_max - exp_avg_min)
|
||||
del log_exp_avg
|
||||
|
||||
state['exp_avg'] = (normalized * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu')
|
||||
|
||||
log_exp_avg_sq = torch.pow(torch.abs(state['exp_avg_sq']), 1.0 / u16power) * torch.sgn(state['exp_avg_sq'])
|
||||
exp_avg_sq_min = torch.min(log_exp_avg_sq)
|
||||
exp_avg_sq_max = torch.max(log_exp_avg_sq)
|
||||
state['exp_avg_sq_min'] = exp_avg_sq_min
|
||||
state['exp_avg_sq_max'] = exp_avg_sq_max
|
||||
normalized_sq = (log_exp_avg_sq - exp_avg_sq_min) / (exp_avg_sq_max - exp_avg_sq_min)
|
||||
del log_exp_avg_sq
|
||||
|
||||
state['exp_avg_sq'] = (normalized_sq * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu')
|
||||
state[f'exp_avg_sq_shape'] = state[f'exp_avg_sq'].shape
|
||||
state[f'exp_avg_sq'] = to_float24_bytes(state[f'exp_avg_sq']).to('cpu')
|
||||
|
||||
# Add on gradient update, but not if using kahan summation as the bottom
|
||||
# bits must be restored first. (This update occurs in copy_kahan_() instead)
|
||||
|
||||
Reference in New Issue
Block a user