fix to call train/eval in schedulefree #1605

This commit is contained in:
Kohya S
2024-09-18 21:31:54 +09:00
parent e74502117b
commit 1286e00bb0
4 changed files with 33 additions and 1 deletions

View File

@@ -13,6 +13,7 @@ import shutil
import time
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
@@ -4715,8 +4716,20 @@ def get_optimizer(args, trainable_params):
return optimizer_name, optimizer_args, optimizer
def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]:
if not is_schedulefree_optimizer(optimizer, args):
# return dummy func
return lambda: None, lambda: None
# get train and eval functions from optimizer
train_fn = optimizer.train
eval_fn = optimizer.eval
return train_fn, eval_fn
def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
def get_dummy_scheduler(optimizer: Optimizer) -> Any: