add stochastic rounding, fix single block

Kohya S
2024-08-21 21:04:10 +09:00
parent 2b07a92c8d
commit e1cd19c0c0
4 changed files with 135 additions and 14 deletions
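
Note: the hunks shown below only rewire the fused-backward gradient hooks; the stochastic rounding named in the commit title is applied where the patched Adafactor step writes updated bf16 weights, in one of the other changed files not shown here. As a rough, non-authoritative sketch of the usual technique (assuming an fp32 update being written back into a bf16 parameter), stochastic rounding adds a random 16-bit value to the bits that plain truncation would discard, so the value rounds up with probability proportional to the discarded fraction:

import torch

def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
    # sketch only: assumes source is a contiguous fp32 tensor and target is bf16
    # draw a random 16-bit integer per element
    noise = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
    # add it to the low 16 bits of the fp32 bit pattern; the carry decides round up vs down
    noise.add_(source.view(dtype=torch.int32))
    # zero the low 16 bits; the remaining high bits are exactly a bf16 value
    noise.bitwise_and_(-65536)  # -65536 == 0xFFFF0000 as a signed int32
    # reinterpret as fp32 and copy into the bf16 target (lossless, since the low bits are zero)
    target.copy_(noise.view(dtype=torch.float32))

In a fused optimizer step this would replace the plain dtype cast when copying the fp32 result back into the bf16 parameter.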


@@ -277,7 +277,10 @@ def train(args):
     training_models = []
     params_to_optimize = []
     training_models.append(flux)
-    params_to_optimize.append({"params": list(flux.parameters()), "lr": args.learning_rate})
+    name_and_params = list(flux.named_parameters())
+    # single param group for now
+    params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
+    param_names = [[n for n, _ in name_and_params]]

     # calculate number of trainable parameters
     n_params = 0
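
The hunk above keeps a single param group but records the parameter names in param_names, group-for-group, so the hook-registration loop in the next hunk can recover which block a parameter belongs to from its module path. For illustration (the name below is a hypothetical example of the usual PyTorch naming scheme):

# e.g. a name produced by flux.named_parameters() -- hypothetical example
name = "double_blocks.7.img_mlp.0.weight"
block_kind = name.split(".")[0]     # "double_blocks"
block_idx = int(name.split(".")[1])  # 7, used by the swap logic in the next hunk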
@@ -433,17 +436,89 @@ def train(args):
         import library.adafactor_fused

         library.adafactor_fused.patch_adafactor_fused(optimizer)
-        for param_group in optimizer.param_groups:
-            for parameter in param_group["params"]:
+
+        double_blocks_to_swap = args.double_blocks_to_swap
+        single_blocks_to_swap = args.single_blocks_to_swap
+        num_double_blocks = len(flux.double_blocks)
+        num_single_blocks = len(flux.single_blocks)
+        handled_double_block_indices = set()
+        handled_single_block_indices = set()
+
+        for param_group, param_name_group in zip(optimizer.param_groups, param_names):
+            for parameter, param_name in zip(param_group["params"], param_name_group):
                 if parameter.requires_grad:
+                    grad_hook = None

-                    def __grad_hook(tensor: torch.Tensor, param_group=param_group):
-                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                        optimizer.step_param(tensor, param_group)
-                        tensor.grad = None
+                    if double_blocks_to_swap:
+                        if param_name.startswith("double_blocks"):
+                            block_idx = int(param_name.split(".")[1])
+                            if (
+                                block_idx not in handled_double_block_indices
+                                and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1
+                                and block_idx < num_double_blocks - 1
+                            ):
+                                # swap next (already backpropagated) block
+                                handled_double_block_indices.add(block_idx)
+                                block_idx_cpu = block_idx + 1
+                                block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu)

-                    parameter.register_post_accumulate_grad_hook(__grad_hook)
+                                # create swap hook
+                                def create_double_swap_grad_hook(bidx, bidx_cuda):
+                                    def __grad_hook(tensor: torch.Tensor):
+                                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                                        optimizer.step_param(tensor, param_group)
+                                        tensor.grad = None
+
+                                        # swap blocks if necessary
+                                        flux.double_blocks[bidx].to("cpu")
+                                        flux.double_blocks[bidx_cuda].to(accelerator.device)
+                                        # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")
+
+                                    return __grad_hook
+
+                                grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda)
+
+                    if single_blocks_to_swap:
+                        if param_name.startswith("single_blocks"):
+                            block_idx = int(param_name.split(".")[1])
+                            if (
+                                block_idx not in handled_single_block_indices
+                                and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1
+                                and block_idx < num_single_blocks - 1
+                            ):
+                                handled_single_block_indices.add(block_idx)
+                                block_idx_cpu = block_idx + 1
+                                block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu)
+                                # print(param_name, block_idx_cpu, block_idx_cuda)
+
+                                # create swap hook
+                                def create_single_swap_grad_hook(bidx, bidx_cuda):
+                                    def __grad_hook(tensor: torch.Tensor):
+                                        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                                        optimizer.step_param(tensor, param_group)
+                                        tensor.grad = None
+
+                                        # swap blocks if necessary
+                                        flux.single_blocks[bidx].to("cpu")
+                                        flux.single_blocks[bidx_cuda].to(accelerator.device)
+                                        # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")
+
+                                    return __grad_hook
+
+                                grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda)
+
+                    if grad_hook is None:
+
+                        def __grad_hook(tensor: torch.Tensor, param_group=param_group):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                            optimizer.step_param(tensor, param_group)
+                            tensor.grad = None
+
+                        grad_hook = __grad_hook
+
+                    parameter.register_post_accumulate_grad_hook(grad_hook)

     elif args.blockwise_fused_optimizers:
         # prepare for additional optimizers and lr schedulers
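
For the swap hooks above, a worked example of the index arithmetic may help. This sketch assumes 19 double blocks (FLUX.1's count) and an arbitrary double_blocks_to_swap = 3, and prints which block the hook moves to CPU and which it brings back to the device once the gradients of block block_idx have been stepped:

# worked example of the index arithmetic in the swap hooks (assumed values)
num_double_blocks = 19       # FLUX.1 double block count
double_blocks_to_swap = 3    # arbitrary example value

for block_idx in range(num_double_blocks):
    if (num_double_blocks - double_blocks_to_swap) - 1 <= block_idx < num_double_blocks - 1:
        block_idx_cpu = block_idx + 1
        block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu)
        print(f"after block {block_idx}: double_blocks[{block_idx_cpu}] -> cpu, double_blocks[{block_idx_cuda}] -> cuda")

# after block 15: double_blocks[16] -> cpu, double_blocks[0] -> cuda
# after block 16: double_blocks[17] -> cpu, double_blocks[1] -> cuda
# after block 17: double_blocks[18] -> cpu, double_blocks[2] -> cuda

The handled_*_block_indices sets make sure only one parameter per block gets the swapping variant of the hook, so each block is moved once per step even though every parameter of that block passes through the same registration check.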