commit d0eba37946
parent 8cee727a99
Author: some_ai
Date:   2025-01-06 01:19:43 -04:00

3 changed files with 19 additions and 13 deletions

View File

@@ -481,15 +481,17 @@ def train(args):
             for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                 for parameter, param_name in zip(param_group["params"], param_name_group):
                     if parameter.requires_grad:
                         def create_grad_hook(p_name, p_group):
                             def grad_hook(tensor: torch.Tensor):
                                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                     accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                                 optimizer.step_param(tensor, p_group)
                                 tensor.grad = None
                             return grad_hook
                         parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-        elif args.optimizer_type == "ProdigyPlusScheduleFree":
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
             # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
             pass
         else:
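
A note on the create_grad_hook factory above: Python closures look up free variables at call time, so a hook defined directly inside the loop would see whatever param_name and param_group held when the loop finished, not the values for its own parameter. Routing them through a factory rebinds them per hook. A minimal, framework-free sketch of the pitfall and the fix (names are illustrative, not from this repo):

    # Late-binding pitfall: every closure reads the loop variable's final value.
    hooks = []
    for group_id in range(3):
        def hook():
            return group_id  # resolved when called, not when defined
        hooks.append(hook)
    print([h() for h in hooks])  # [2, 2, 2]

    # Factory fix, mirroring create_grad_hook: the factory's argument is
    # evaluated when the factory is called, so each returned closure keeps
    # its own value.
    def create_hook(gid):
        def hook():
            return gid
        return hook
    hooks = [create_hook(group_id) for group_id in range(3)]
    print([h() for h in hooks])  # [0, 1, 2]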

View File

@@ -602,18 +602,22 @@ def train(args):
         if args.optimizer_type == "AdaFactor":
             import library.adafactor_fused
             library.adafactor_fused.patch_adafactor_fused(optimizer)
             for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                 for parameter, param_name in zip(param_group["params"], param_name_group):
                     if parameter.requires_grad:
                         def create_grad_hook(p_name, p_group):
                             def grad_hook(tensor: torch.Tensor):
                                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                     accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                                 optimizer.step_param(tensor, p_group)
                                 tensor.grad = None
                             return grad_hook
                         parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-        elif args.optimizer_type == "ProdigyPlusScheduleFree":
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
             # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
             pass
         else:
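
For context on the registration above: Tensor.register_post_accumulate_grad_hook (available since PyTorch 2.1) fires once a parameter's gradient has been fully accumulated during backward. That lets the optimizer step that single parameter and drop its gradient immediately, instead of keeping every .grad alive until a global optimizer.step(), which is the point of the fused backward pass. A self-contained sketch of the mechanism, with a plain SGD update standing in for the AdaFactor-specific optimizer.step_param (an assumption for illustration, not this repo's code):

    import torch

    model = torch.nn.Linear(8, 8)
    lr = 0.1

    def fused_sgd_hook(p: torch.Tensor) -> None:
        # Runs right after p.grad is fully accumulated for this backward pass.
        with torch.no_grad():
            p -= lr * p.grad  # stand-in for optimizer.step_param(p, group)
        p.grad = None  # free the gradient immediately to cut peak memory

    for p in model.parameters():
        if p.requires_grad:
            p.register_post_accumulate_grad_hook(fused_sgd_hook)

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()  # parameters update inside the hooks; no optimizer.step()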

View File

@@ -525,18 +525,18 @@ def train(args):
         if args.optimizer_type == "AdaFactor":
             import library.adafactor_fused
             library.adafactor_fused.patch_adafactor_fused(optimizer)
-            for param_group, param_name_group in zip(optimizer.param_groups, param_names):
-                for parameter, param_name in zip(param_group["params"], param_name_group):
+            for param_group in optimizer.param_groups:
+                for parameter in param_group["params"]:
                     if parameter.requires_grad:
-                        def create_grad_hook(p_name, p_group):
-                            def grad_hook(tensor: torch.Tensor):
-                                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                    accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                                optimizer.step_param(tensor, p_group)
-                                tensor.grad = None
-                            return grad_hook
-                        parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-        elif args.optimizer_type == "ProdigyPlusScheduleFree":
+
+                        def __grad_hook(tensor: torch.Tensor, param_group=param_group):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                            optimizer.step_param(tensor, param_group)
+                            tensor.grad = None
+
+                        parameter.register_post_accumulate_grad_hook(__grad_hook)
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
             # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
             pass
         else:
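
The param_group=param_group default in __grad_hook above is the other standard cure for the late-binding problem the create_grad_hook factory solves in the first two files: a default value is evaluated once, when the def executes, so each registered hook keeps the param_group of its own loop iteration even though the hook only runs later, during backward. A minimal sketch of the idiom (illustrative names only):

    # Default arguments are evaluated at definition time, so each function
    # freezes the loop variable's value for its own iteration.
    hooks = []
    for group_id in range(3):
        def hook(group_id=group_id):  # bind the current value now
            return group_id
        hooks.append(hook)
    print([h() for h in hooks])  # [0, 1, 2] rather than [2, 2, 2]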