Update train_util.py for DAdaptadam

This commit is contained in:
青龍聖者@bdsqlsz
2023-04-26 16:51:12 +08:00
committed by GitHub
parent 0c9c90a87e
commit 0db2eddace

View File

@@ -2535,7 +2535,7 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.SGD
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
elif optimizer_type == "DAdaptation".lower():
elif optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdam".lower():
try:
import dadaptation
except ImportError: