diff --git a/fine_tune.py b/fine_tune.py index b6a8d1d7..4227dd04 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -376,7 +376,7 @@ def train(args): current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + if args.optimizer_type.lower() == "DAdaptation".lower() or args.optimizer_type.lower() == "DAdaptAdam".lower() or args.optimizer_type.lower() == "DAdaptAdaGrad".lower() or args.optimizer_type.lower() == "DAdaptAdan".lower() or args.optimizer_type.lower() == "DAdaptSGD".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] )