This commit is contained in:
Dave Lage
2025-10-10 19:26:56 +01:00
committed by GitHub
3 changed files with 266 additions and 0 deletions

View File

@@ -767,6 +767,9 @@ class NetworkTrainer:
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
if hasattr(network, "register_optimizer"):
network.register_optimizer(optimizer)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -1453,6 +1456,8 @@ class NetworkTrainer:
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()
if hasattr(network, "update_gradient_ema"):
network.update_gradient_ema()
optimizer.step()
lr_scheduler.step()