refactor selection and logging for DAdaptation

This commit is contained in:
Kohya S
2023-05-06 18:14:16 +09:00
parent 164a1978de
commit 2127907dd3
6 changed files with 28 additions and 98 deletions

View File

@@ -367,7 +367,7 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): # tracking d*lr value
if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)