mirror of https://github.com/kohya-ss/sd-scripts.git
fix
@@ -481,15 +481,17 @@ def train(args):
                 for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                     for parameter, param_name in zip(param_group["params"], param_name_group):
                         if parameter.requires_grad:

                             def create_grad_hook(p_name, p_group):
                                 def grad_hook(tensor: torch.Tensor):
                                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                         accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                                     optimizer.step_param(tensor, p_group)
                                     tensor.grad = None

                                 return grad_hook

                             parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-            elif args.optimizer_type == "ProdigyPlusScheduleFree":
+            elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
+                # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
+                pass
             else:
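Note: the hook registration in the hunk above is what implements the fused backward pass for AdaFactor: each parameter's gradient is clipped, applied through the patched optimizer's step_param, and freed as soon as it has finished accumulating, so gradients for the whole model never have to be held at once. Below is a minimal sketch of the same idea in stock PyTorch (2.1 or newer for register_post_accumulate_grad_hook). It is not the sd-scripts code path: it substitutes one small optimizer per parameter for the patched step_param, and the model, loss, and learning rate are placeholders.

import torch
import torch.nn as nn

# Placeholder model; any nn.Module works.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
max_grad_norm = 1.0

# One tiny optimizer per parameter, so each parameter can be updated the moment
# its gradient is ready and the gradient memory can be released immediately.
optimizers = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters() if p.requires_grad}

def create_grad_hook(opt):
    # Factory function, as in the diff above, so each hook binds its own optimizer.
    def grad_hook(param: torch.Tensor):
        # Runs after param.grad has been fully accumulated for this backward pass.
        if max_grad_norm != 0.0:
            torch.nn.utils.clip_grad_norm_([param], max_grad_norm)
        opt.step()
        opt.zero_grad(set_to_none=True)  # drop the gradient right away
    return grad_hook

for p, opt in optimizers.items():
    p.register_post_accumulate_grad_hook(create_grad_hook(opt))

# Usage: backward() now also performs the per-parameter optimizer updates.
x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()  # no separate optimizer.step() afterwards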
@@ -602,18 +602,22 @@ def train(args):
             if args.optimizer_type == "AdaFactor":
                 import library.adafactor_fused
                 library.adafactor_fused.patch_adafactor_fused(optimizer)

                 for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                     for parameter, param_name in zip(param_group["params"], param_name_group):
                         if parameter.requires_grad:

                             def create_grad_hook(p_name, p_group):
                                 def grad_hook(tensor: torch.Tensor):
                                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                         accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                                     optimizer.step_param(tensor, p_group)
                                     tensor.grad = None

                                 return grad_hook

                             parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-            elif args.optimizer_type == "ProdigyPlusScheduleFree":
+            elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
+                # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
+                pass
             else:
@@ -525,18 +525,18 @@ def train(args):
             if args.optimizer_type == "AdaFactor":
                 import library.adafactor_fused
                 library.adafactor_fused.patch_adafactor_fused(optimizer)
-                for param_group in optimizer.param_groups:
-                    for parameter in param_group["params"]:
+                for param_group, param_name_group in zip(optimizer.param_groups, param_names):
+                    for parameter, param_name in zip(param_group["params"], param_name_group):
                         if parameter.requires_grad:

-                            def __grad_hook(tensor: torch.Tensor, param_group=param_group):
-                                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                    accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                                optimizer.step_param(tensor, param_group)
-                                tensor.grad = None
-
-                            parameter.register_post_accumulate_grad_hook(__grad_hook)
-            elif args.optimizer_type == "ProdigyPlusScheduleFree":
+                            def create_grad_hook(p_name, p_group):
+                                def grad_hook(tensor: torch.Tensor):
+                                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                        accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                                    optimizer.step_param(tensor, p_group)
+                                    tensor.grad = None
+
+                                return grad_hook
+
+                            parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
+            elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
                 # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
                 pass
             else:
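A side note on the last hunk: the old inline __grad_hook bound param_group through a default argument (param_group=param_group), while the replacement uses the create_grad_hook factory. Both work around Python's late-binding closures; a hook that closed over the loop variable directly would see only the final parameter group. A small, self-contained illustration, independent of sd-scripts:

hooks_broken, hooks_default, hooks_factory = [], [], []

for i in range(3):
    # Broken: the lambda looks up `i` when called, so every hook sees i == 2.
    hooks_broken.append(lambda: i)

    # Fix 1: bind the current value via a default argument (the old __grad_hook style).
    hooks_default.append(lambda i=i: i)

    # Fix 2: a factory creates a fresh scope per iteration (the new create_grad_hook style).
    def make_hook(j):
        return lambda: j

    hooks_factory.append(make_hook(i))

print([h() for h in hooks_broken])   # [2, 2, 2]
print([h() for h in hooks_default])  # [0, 1, 2]
print([h() for h in hooks_factory])  # [0, 1, 2]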