mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
sd3 schedule free opt (#1605)
* New ScheduleFree support for Flux (#1600) * init * use no schedule * fix typo * update for eval() * fix typo * update * Update train_util.py * Update requirements.txt * update sfwrapper WIP * no need to check schedulefree optimizer * remove debug print * comment out schedulefree wrapper * update readme --------- Co-authored-by: 青龍聖者@bdsqlsz <865105819@qq.com>
This commit is contained in:
@@ -3303,6 +3303,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
|
||||
)
|
||||
|
||||
# parser.add_argument(
|
||||
# "--optimizer_schedulefree_wrapper",
|
||||
# action="store_true",
|
||||
# help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
|
||||
# )
|
||||
|
||||
# parser.add_argument(
|
||||
# "--schedulefree_wrapper_args",
|
||||
# type=str,
|
||||
# default=None,
|
||||
# nargs="*",
|
||||
# help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
|
||||
# )
|
||||
|
||||
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_args",
|
||||
@@ -4582,26 +4596,146 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = torch.optim.AdamW
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.endswith("schedulefree".lower()):
|
||||
try:
|
||||
import schedulefree as sf
|
||||
except ImportError:
|
||||
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
|
||||
if optimizer_type == "AdamWScheduleFree".lower():
|
||||
optimizer_class = sf.AdamWScheduleFree
|
||||
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "SGDScheduleFree".lower():
|
||||
optimizer_class = sf.SGDScheduleFree
|
||||
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
# make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop
|
||||
optimizer.train()
|
||||
|
||||
if optimizer is None:
|
||||
# 任意のoptimizerを使う
|
||||
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
||||
logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
|
||||
if "." not in optimizer_type:
|
||||
optimizer_module = torch.optim
|
||||
else:
|
||||
values = optimizer_type.split(".")
|
||||
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
||||
optimizer_type = values[-1]
|
||||
case_sensitive_optimizer_type = args.optimizer_type # not lower
|
||||
logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
|
||||
|
||||
optimizer_class = getattr(optimizer_module, optimizer_type)
|
||||
if "." not in case_sensitive_optimizer_type: # from torch.optim
|
||||
optimizer_module = torch.optim
|
||||
else: # from other library
|
||||
values = case_sensitive_optimizer_type.split(".")
|
||||
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
||||
case_sensitive_optimizer_type = values[-1]
|
||||
|
||||
optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
"""
|
||||
# wrap any of above optimizer with schedulefree, if optimizer is not schedulefree
|
||||
if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()):
|
||||
try:
|
||||
import schedulefree as sf
|
||||
except ImportError:
|
||||
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
|
||||
|
||||
schedulefree_wrapper_kwargs = {}
|
||||
if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
|
||||
for arg in args.schedulefree_wrapper_args:
|
||||
key, value = arg.split("=")
|
||||
value = ast.literal_eval(value)
|
||||
schedulefree_wrapper_kwargs[key] = value
|
||||
|
||||
sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
|
||||
sf_wrapper.train() # make optimizer as train mode
|
||||
|
||||
# we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper
|
||||
class OptimizerProxy(torch.optim.Optimizer):
|
||||
def __init__(self, sf_wrapper):
|
||||
self._sf_wrapper = sf_wrapper
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._sf_wrapper, name)
|
||||
|
||||
# override properties
|
||||
@property
|
||||
def state(self):
|
||||
return self._sf_wrapper.state
|
||||
|
||||
@state.setter
|
||||
def state(self, state):
|
||||
self._sf_wrapper.state = state
|
||||
|
||||
@property
|
||||
def param_groups(self):
|
||||
return self._sf_wrapper.param_groups
|
||||
|
||||
@param_groups.setter
|
||||
def param_groups(self, param_groups):
|
||||
self._sf_wrapper.param_groups = param_groups
|
||||
|
||||
@property
|
||||
def defaults(self):
|
||||
return self._sf_wrapper.defaults
|
||||
|
||||
@defaults.setter
|
||||
def defaults(self, defaults):
|
||||
self._sf_wrapper.defaults = defaults
|
||||
|
||||
def add_param_group(self, param_group):
|
||||
self._sf_wrapper.add_param_group(param_group)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._sf_wrapper.load_state_dict(state_dict)
|
||||
|
||||
def state_dict(self):
|
||||
return self._sf_wrapper.state_dict()
|
||||
|
||||
def zero_grad(self):
|
||||
self._sf_wrapper.zero_grad()
|
||||
|
||||
def step(self, closure=None):
|
||||
self._sf_wrapper.step(closure)
|
||||
|
||||
def train(self):
|
||||
self._sf_wrapper.train()
|
||||
|
||||
def eval(self):
|
||||
self._sf_wrapper.eval()
|
||||
|
||||
# isinstance チェックをパスするためのメソッド
|
||||
def __instancecheck__(self, instance):
|
||||
return isinstance(instance, (type(self), Optimizer))
|
||||
|
||||
optimizer = OptimizerProxy(sf_wrapper)
|
||||
|
||||
logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}")
|
||||
"""
|
||||
|
||||
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
||||
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
|
||||
|
||||
return optimizer_name, optimizer_args, optimizer
|
||||
|
||||
|
||||
def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
|
||||
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
|
||||
|
||||
|
||||
def get_dummy_scheduler(optimizer: Optimizer) -> Any:
|
||||
# dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
|
||||
# this scheduler is used for logging only.
|
||||
# this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
|
||||
class DummyScheduler:
|
||||
def __init__(self, optimizer: Optimizer):
|
||||
self.optimizer = optimizer
|
||||
|
||||
def step(self):
|
||||
pass
|
||||
|
||||
def get_last_lr(self):
|
||||
return [group["lr"] for group in self.optimizer.param_groups]
|
||||
|
||||
return DummyScheduler(optimizer)
|
||||
|
||||
|
||||
# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
|
||||
# Add some checking and features to the original function.
|
||||
|
||||
@@ -4610,6 +4744,10 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
"""
|
||||
Unified API to get any scheduler from its name.
|
||||
"""
|
||||
# if schedulefree optimizer, return dummy scheduler
|
||||
if is_schedulefree_optimizer(optimizer, args):
|
||||
return get_dummy_scheduler(optimizer)
|
||||
|
||||
name = args.lr_scheduler
|
||||
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
|
||||
num_warmup_steps: Optional[int] = (
|
||||
|
||||
Reference in New Issue
Block a user