mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
Changed cpu storage of exp_avg[_sq] from bf16 to powed/scaled u16
This commit is contained in:
@@ -23,11 +23,27 @@ def adamw_offload_step_param(self, p, group):
|
||||
if p.dtype in {torch.float16, torch.bfloat16}:
|
||||
p_data_fp32 = p_data_fp32.float()
|
||||
|
||||
# Tensors with few elements may be more sensitive to quantization
|
||||
# errors, so keep them in float32
|
||||
high_quality = torch.numel(p) <= 4096
|
||||
|
||||
# State Initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p, dtype=torch.bfloat16)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.bfloat16)
|
||||
|
||||
if high_quality:
|
||||
# Exponential averages stored in f32 format
|
||||
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float32)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float32)
|
||||
else:
|
||||
# Exponential averages stored in u16 format
|
||||
state['exp_avg'] = torch.zeros_like(p, dtype=torch.uint16)
|
||||
state['exp_avg_min'] = 0.0
|
||||
state['exp_avg_max'] = 1.0
|
||||
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.uint16)
|
||||
state['exp_avg_sq_min'] = 0.0
|
||||
state['exp_avg_sq_max'] = 1.0
|
||||
|
||||
state["step"] += 1
|
||||
|
||||
@@ -44,8 +60,34 @@ def adamw_offload_step_param(self, p, group):
|
||||
eps_p2: float = math.pow(eps, 2)
|
||||
|
||||
# Bring state back from CPU
|
||||
state['exp_avg'] = state['exp_avg'] .to('cuda').to(dtype=torch.float32)
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].to('cuda').to(dtype=torch.float32)
|
||||
|
||||
if high_quality:
|
||||
# These exponential averages are already in float32 format
|
||||
state['exp_avg'] = state['exp_avg'] .to(p.device)
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].to(p.device)
|
||||
else:
|
||||
# Unpack these exponential averages from uint16 format
|
||||
|
||||
# A power function was applied to the tensor values, as they are usually
|
||||
# distributed in an exponential fashion. After the power function was applied,
|
||||
# the min and max of the results were noted, and then the values were scaled
|
||||
# to the 0-65535 range for storage. This process is reversed here.
|
||||
|
||||
u16power = 16.0 # This value worked acceptably in testing to spread the values more evenly
|
||||
|
||||
exp_avg_min = state['exp_avg_min']
|
||||
exp_avg_max = state['exp_avg_max']
|
||||
exp_avg_sq_min = state['exp_avg_sq_min']
|
||||
exp_avg_sq_max = state['exp_avg_sq_max']
|
||||
|
||||
uint16_recreate_a = state['exp_avg'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_max - exp_avg_min) + exp_avg_min
|
||||
state['exp_avg'] = torch.pow(torch.abs(uint16_recreate_a), u16power) * torch.sgn(uint16_recreate_a)
|
||||
del uint16_recreate_a
|
||||
|
||||
uint16_recreate_a_sq = state['exp_avg_sq'].to(p.device).to(dtype=torch.float32) / 65535.0 * (exp_avg_sq_max - exp_avg_sq_min) + exp_avg_sq_min
|
||||
state['exp_avg_sq'] = torch.pow(torch.abs(uint16_recreate_a_sq), u16power) * torch.sgn(uint16_recreate_a_sq)
|
||||
del uint16_recreate_a_sq
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
|
||||
# Update biased first and second moment estimates
|
||||
@@ -77,9 +119,51 @@ def adamw_offload_step_param(self, p, group):
|
||||
if weight_decay != 0:
|
||||
p_data_fp32.mul_(1 - lr * weight_decay)
|
||||
|
||||
# Keep state on CPU
|
||||
state['exp_avg'] = state['exp_avg'] .to(dtype=torch.bfloat16).to('cpu')
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].to(dtype=torch.bfloat16).to('cpu')
|
||||
if high_quality:
|
||||
|
||||
# These are kept in f32 format between steps
|
||||
state['exp_avg'] = state['exp_avg'].to('cpu')
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].to('cpu')
|
||||
|
||||
else:
|
||||
|
||||
# Compress the exp_avg and exp_avg_sq tensors to cut their size down
|
||||
# from 32 bit to 16 bit.
|
||||
#
|
||||
# A power function is applied to try to linearize the tensor values, as
|
||||
# they are usually distributed in an exponential fashion. It would have
|
||||
# been preferable to use a log() function, but the input values can be
|
||||
# negative, so a pow() function is used instead. The 1/16th power was
|
||||
# chosen fairly arbitrarily, but seemed to distribute the values fairly
|
||||
# reasonably in some simple tests.
|
||||
#
|
||||
# After the power function is applied, the min and max of the resulting
|
||||
# values are stored, and the values are then scaled to the 0-65535 range
|
||||
# for storage.
|
||||
#
|
||||
# Doing this instead of storing these values as bf16 reduced the L1
|
||||
# error between the stored values and the true f32 values by around 90%,
|
||||
# with a notable increase in output image quality.
|
||||
|
||||
log_exp_avg = torch.pow(torch.abs(state['exp_avg']), 1.0 / u16power) * torch.sgn(state['exp_avg'])
|
||||
exp_avg_min = torch.min(log_exp_avg)
|
||||
exp_avg_max = torch.max(log_exp_avg)
|
||||
state['exp_avg_min'] = exp_avg_min
|
||||
state['exp_avg_max'] = exp_avg_max
|
||||
normalized = (log_exp_avg - exp_avg_min) / (exp_avg_max - exp_avg_min)
|
||||
del log_exp_avg
|
||||
|
||||
state['exp_avg'] = (normalized * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu')
|
||||
|
||||
log_exp_avg_sq = torch.pow(torch.abs(state['exp_avg_sq']), 1.0 / u16power) * torch.sgn(state['exp_avg_sq'])
|
||||
exp_avg_sq_min = torch.min(log_exp_avg_sq)
|
||||
exp_avg_sq_max = torch.max(log_exp_avg_sq)
|
||||
state['exp_avg_sq_min'] = exp_avg_sq_min
|
||||
state['exp_avg_sq_max'] = exp_avg_sq_max
|
||||
normalized_sq = (log_exp_avg_sq - exp_avg_sq_min) / (exp_avg_sq_max - exp_avg_sq_min)
|
||||
del log_exp_avg_sq
|
||||
|
||||
state['exp_avg_sq'] = (normalized_sq * 65535.0).clamp(0, 65535).to(dtype=torch.uint16).to('cpu')
|
||||
|
||||
# Add on gradient update, but not if using kahan summation as the bottom
|
||||
# bits must be restored first. (This update occurs in copy_kahan_() instead)
|
||||
|
||||
Reference in New Issue
Block a user