mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
pass scheduler args to custom scheduler as kwargs if supported
This commit is contained in:
@@ -76,6 +76,7 @@ import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from inspect import signature, Parameter
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -4431,8 +4432,27 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
values = lr_scheduler_type.split(".")
|
||||
lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
|
||||
lr_scheduler_type = values[-1]
|
||||
|
||||
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
|
||||
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
|
||||
|
||||
# if constructor supports kwargs, pass previously calculated data
|
||||
req_param = signature(lr_scheduler_class).parameters
|
||||
if any(param.kind == Parameter.VAR_KEYWORD for param in req_param.values()):
|
||||
packed_args = {
|
||||
"num_training_steps": num_training_steps,
|
||||
"num_warmup_steps": num_warmup_steps,
|
||||
"num_decay_steps": num_decay_steps,
|
||||
"num_stable_steps": num_stable_steps,
|
||||
"num_cycles": num_cycles,
|
||||
"power": power,
|
||||
"timescale": timescale,
|
||||
"min_lr_ratio": min_lr_ratio,
|
||||
}
|
||||
# explicitly defined lr_scheduler_args overwrite packed_args
|
||||
lr_scheduler = lr_scheduler_class(optimizer, **{**packed_args, **lr_scheduler_kwargs})
|
||||
else:
|
||||
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
|
||||
|
||||
return wrap_check_needless_num_warmup_steps(lr_scheduler)
|
||||
|
||||
if name.startswith("adafactor"):
|
||||
|
||||
Reference in New Issue
Block a user