pass scheduler args to custom scheduler as kwargs if supported

This commit is contained in:
uwidev
2025-01-19 13:01:47 -08:00
parent 345daaa986
commit b0fea6a2eb

View File

@@ -76,6 +76,7 @@ import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize
from inspect import signature, Parameter
setup_logging()
import logging
@@ -4431,8 +4432,27 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
values = lr_scheduler_type.split(".")
lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
lr_scheduler_type = values[-1]
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
# if constructor supports kwargs, pass previously calculated data
req_param = signature(lr_scheduler_class).parameters
if any(param.kind == Parameter.VAR_KEYWORD for param in req_param.values()):
packed_args = {
"num_training_steps": num_training_steps,
"num_warmup_steps": num_warmup_steps,
"num_decay_steps": num_decay_steps,
"num_stable_steps": num_stable_steps,
"num_cycles": num_cycles,
"power": power,
"timescale": timescale,
"min_lr_ratio": min_lr_ratio,
}
# explicitly defined lr_scheduler_args overwrite packed_args
lr_scheduler = lr_scheduler_class(optimizer, **{**packed_args, **lr_scheduler_kwargs})
else:
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
return wrap_check_needless_num_warmup_steps(lr_scheduler)
if name.startswith("adafactor"):