From b0fea6a2ebb3d9b95b8075e0ce36473216da32c5 Mon Sep 17 00:00:00 2001
From: uwidev
Date: Sun, 19 Jan 2025 13:01:47 -0800
Subject: [PATCH] pass scheduler args to custom scheduler as kwargs if supported

---
 library/train_util.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/library/train_util.py b/library/train_util.py
index 100ef475..15486dd4 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -76,6 +76,7 @@ import library.huggingface_util as huggingface_util
 import library.sai_model_spec as sai_model_spec
 import library.deepspeed_utils as deepspeed_utils
 from library.utils import setup_logging, pil_resize
+from inspect import signature, Parameter
 
 setup_logging()
 import logging
@@ -4431,8 +4432,27 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
             values = lr_scheduler_type.split(".")
             lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
             lr_scheduler_type = values[-1]
+
         lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
-        lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+
+        # if constructor supports kwargs, pass previously calculated data
+        req_param = signature(lr_scheduler_class).parameters
+        if any(param.kind == Parameter.VAR_KEYWORD for param in req_param.values()):
+            packed_args = {
+                "num_training_steps": num_training_steps,
+                "num_warmup_steps": num_warmup_steps,
+                "num_decay_steps": num_decay_steps,
+                "num_stable_steps": num_stable_steps,
+                "num_cycles": num_cycles,
+                "power": power,
+                "timescale": timescale,
+                "min_lr_ratio": min_lr_ratio,
+            }
+            # explicitly defined lr_scheduler_args overwrite packed_args
+            lr_scheduler = lr_scheduler_class(optimizer, **{**packed_args, **lr_scheduler_kwargs})
+        else:
+            lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+
         return wrap_check_needless_num_warmup_steps(lr_scheduler)
 
     if name.startswith("adafactor"):
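
Editor's note (not part of the patch): below is a minimal sketch of a custom scheduler that would benefit from this change. The module and function names are hypothetical; the sketch only assumes the behavior added above, namely that get_scheduler_fix forwards its pre-computed values (num_training_steps, num_warmup_steps, num_cycles, min_lr_ratio, ...) whenever the scheduler constructor declares **kwargs, and that values passed explicitly through lr_scheduler_kwargs still take precedence because they are unpacked last.

    # my_schedulers.py -- hypothetical module; would be selected with a custom
    # scheduler type such as "my_schedulers.warmup_cosine".
    import math

    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import LambdaLR


    def warmup_cosine(optimizer: Optimizer, num_warmup_steps: int = 0, num_training_steps: int = 1, **kwargs) -> LambdaLR:
        # The **kwargs parameter makes this constructor eligible for the packed_args
        # forwarding above; unused entries (power, timescale, num_decay_steps, ...)
        # are simply absorbed here instead of raising TypeError.
        def lr_lambda(step: int) -> float:
            if step < num_warmup_steps:
                return step / max(1, num_warmup_steps)
            progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

        return LambdaLR(optimizer, lr_lambda)

With this patch, such a scheduler no longer needs the warmup and total step counts restated on the command line: the values already computed in get_scheduler_fix arrive through **kwargs, while any explicitly supplied scheduler arguments continue to override them.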