add stochastic rounding, fix single block

Kohya S
2024-08-21 21:04:10 +09:00
parent 2b07a92c8d
commit e1cd19c0c0
4 changed files with 135 additions and 14 deletions

View File

@@ -2,6 +2,32 @@ import math
 import torch
 from transformers import Adafactor

+
+# stochastic rounding for bfloat16
+# The implementation was provided by 2kpr. Thank you very much!
+def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
+    """
+    copies source into target using stochastic rounding
+
+    Args:
+        target: the target tensor with dtype=bfloat16
+        source: the source tensor with dtype=float32
+    """
+    # create a random 16 bit integer
+    result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
+
+    # add the random number to the lower 16 bit of the mantissa
+    result.add_(source.view(dtype=torch.int32))
+
+    # mask off the lower 16 bit of the mantissa
+    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
+
+    # copy the higher 16 bit into the target tensor
+    target.copy_(result.view(dtype=torch.float32))
+
+    del result
+
+
 @torch.no_grad()
 def adafactor_step_param(self, p, group):
     if p.grad is None:
@@ -48,7 +74,7 @@ def adafactor_step_param(self, p, group):
     lr = Adafactor._get_lr(group, state)
     beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
-    update = (grad ** 2) + group["eps"][0]
+    update = (grad**2) + group["eps"][0]
     if factored:
         exp_avg_sq_row = state["exp_avg_sq_row"]
         exp_avg_sq_col = state["exp_avg_sq_col"]
@@ -78,7 +104,12 @@ def adafactor_step_param(self, p, group):
     p_data_fp32.add_(-update)

-    if p.dtype in {torch.float16, torch.bfloat16}:
-        p.copy_(p_data_fp32)
+    # if p.dtype in {torch.float16, torch.bfloat16}:
+    #     p.copy_(p_data_fp32)
+
+    if p.dtype == torch.bfloat16:
+        copy_stochastic_(p, p_data_fp32)
+    elif p.dtype == torch.float16:
+        p.copy_(p_data_fp32)
@@ -101,6 +132,7 @@ def adafactor_step(self, closure=None):
     return loss

+
 def patch_adafactor_fused(optimizer: Adafactor):
     optimizer.step_param = adafactor_step_param.__get__(optimizer)
     optimizer.step = adafactor_step.__get__(optimizer)
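
For intuition, a minimal sketch, not part of the commit, of why the stochastic copy is unbiased where a plain cast is not; the test value and tensor size are invented, and copy_stochastic_ is the function added above:

import torch

x = torch.full((100_000,), 1.0 + 2**-10, dtype=torch.float32)  # offset below bf16's half-ulp at 1.0

# plain cast: round-to-nearest sends every element to exactly 1.0, losing the offset
plain = x.to(torch.bfloat16).float().mean()  # -> 1.0

# stochastic rounding: each element rounds up with probability offset/ulp (= 1/8 here),
# so the mean converges to the true value 1.0 + 2**-10
stoch = torch.empty_like(x, dtype=torch.bfloat16)
copy_stochastic_(stoch, x)  # copy_stochastic_ as defined in the diff above
print(plain.item(), stoch.float().mean().item())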

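And a usage sketch for the patch itself; the bf16 parameter, the hyperparameters, and having patch_adafactor_fused in scope are assumptions for illustration, not taken from the training scripts:

import torch
from transformers import Adafactor
# patch_adafactor_fused as defined above; its module path is not shown in this diff

param = torch.nn.Parameter(torch.randn(16, 16, dtype=torch.bfloat16))
optimizer = Adafactor([param], lr=1e-3, relative_step=False, scale_parameter=False)
patch_adafactor_fused(optimizer)  # rebinds step_param/step to the fused versions

param.grad = torch.randn_like(param)
optimizer.step()  # bf16 params are now written back via copy_stochastic_
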
View File

@@ -1078,6 +1078,7 @@ class Flux(nn.Module):
             if moving:
                 self.single_blocks[to_cpu_block_index].to("cpu")  # , non_blocking=True)
                 # print(f"Moved single block {to_cpu_block_index} to cpu.")
+                to_cpu_block_index += 1

         img = img[:, txt.shape[1] :, ...]
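
For context, a rough sketch of the offloading loop this one-line fix targets; the names mirror the diff, but the loop structure is assumed rather than shown. Without the increment, the same single block is evicted on every iteration while the rest accumulate on the GPU.

# hypothetical shape of the single-block loop in Flux.forward
to_cpu_block_index = 0
for block in self.single_blocks:
    block.to("cuda")                  # bring the next block in
    img = block(img, vec=vec, pe=pe)  # argument names are assumptions
    if moving:
        self.single_blocks[to_cpu_block_index].to("cpu")  # evict the oldest resident block
        to_cpu_block_index += 1       # the added line: advance the eviction index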