mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix to call train/eval in schedulefree #1605
This commit is contained in:
@@ -13,6 +13,7 @@ import shutil
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
@@ -4715,8 +4716,20 @@ def get_optimizer(args, trainable_params):
|
||||
return optimizer_name, optimizer_args, optimizer
|
||||
|
||||
|
||||
def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]:
|
||||
if not is_schedulefree_optimizer(optimizer, args):
|
||||
# return dummy func
|
||||
return lambda: None, lambda: None
|
||||
|
||||
# get train and eval functions from optimizer
|
||||
train_fn = optimizer.train
|
||||
eval_fn = optimizer.eval
|
||||
|
||||
return train_fn, eval_fn
|
||||
|
||||
|
||||
def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
|
||||
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
|
||||
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
|
||||
|
||||
|
||||
def get_dummy_scheduler(optimizer: Optimizer) -> Any:
|
||||
|
||||
Reference in New Issue
Block a user