Mirror of https://github.com/kohya-ss/sd-scripts.git
sd3 schedule free opt (#1605)
* New ScheduleFree support for Flux (#1600)
* init
* use no schedule
* fix typo
* update for eval()
* fix typo
* update
* Update train_util.py
* Update requirements.txt
* update sfwrapper WIP
* no need to check schedulefree optimizer
* remove debug print
* comment out schedulefree wrapper
* update readme

Co-authored-by: 青龍聖者@bdsqlsz <865105819@qq.com>
README.md:

@@ -11,6 +11,14 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
 
+Sep 18, 2024:
+
+- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details.
+- `schedulefree` is added to the dependencies. Please update the library if necessary.
+- AdamWScheduleFree or SGDScheduleFree can be used. Specify `adamwschedulefree` or `sgdschedulefree` in `--optimizer_type`.
+- Wrapper classes are not available for now.
+- These can be used not only for FLUX.1 training but also for other training scripts after merging to the dev/main branch.
+
 Sep 16, 2024:
 
 Added `train_double_block_indices` and `train_single_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details.
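For context on the new `--optimizer_type` values, here is a minimal sketch of the train/eval contract that schedule-free optimizers follow (assuming the `schedulefree` package is installed; the model, data, and step count are placeholders, not part of this PR):

```python
import torch
import schedulefree

model = torch.nn.Linear(16, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()  # schedule-free optimizers must be in train mode while stepping
for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()  # no LR scheduler: the schedule is folded into the optimizer

optimizer.eval()  # switch to eval mode before validation or saving weights
```

This is why the commit below returns a dummy scheduler for these optimizers: the learning-rate schedule lives inside the optimizer itself.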
train_util.py:

@@ -3303,6 +3303,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
     )
 
+    # parser.add_argument(
+    #     "--optimizer_schedulefree_wrapper",
+    #     action="store_true",
+    #     help="use schedulefree_wrapper for any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
+    # )
+
+    # parser.add_argument(
+    #     "--schedulefree_wrapper_args",
+    #     type=str,
+    #     default=None,
+    #     nargs="*",
+    #     help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / schedulefree_wrapperの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
+    # )
+
     parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
     parser.add_argument(
         parser.add_argument(
         "--lr_scheduler_args",
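These `nargs="*"` options take space-separated `key=value` strings. A standalone sketch of how such lists become constructor kwargs, mirroring the `split("=")` + `ast.literal_eval` pattern that appears in the commented-out wrapper code in `get_optimizer` below (the helper name `parse_kwargs` is hypothetical):

```python
import ast

def parse_kwargs(pairs: list[str]) -> dict:
    kwargs = {}
    for pair in pairs:
        key, value = pair.split("=")
        kwargs[key] = ast.literal_eval(value)  # "0.01" -> 0.01, "0.9,0.999" -> (0.9, 0.999)
    return kwargs

print(parse_kwargs(["weight_decay=0.01", "betas=0.9,0.999"]))
# {'weight_decay': 0.01, 'betas': (0.9, 0.999)}
```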
train_util.py:

@@ -4582,26 +4596,146 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
+    elif optimizer_type.endswith("schedulefree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+        if optimizer_type == "AdamWScheduleFree".lower():
+            optimizer_class = sf.AdamWScheduleFree
+            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "SGDScheduleFree".lower():
+            optimizer_class = sf.SGDScheduleFree
+            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
+        else:
+            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+        # set the optimizer to train mode; there is no need to call train() again later, because eval() is never called in the training loop
+        optimizer.train()
+
     if optimizer is None:
         # use any optimizer
-        optimizer_type = args.optimizer_type  # not the lowercased one (not ideal)
-        logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
-        if "." not in optimizer_type:
-            optimizer_module = torch.optim
-        else:
-            values = optimizer_type.split(".")
-            optimizer_module = importlib.import_module(".".join(values[:-1]))
-            optimizer_type = values[-1]
-
-        optimizer_class = getattr(optimizer_module, optimizer_type)
+        case_sensitive_optimizer_type = args.optimizer_type  # not lower
+        logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
+        if "." not in case_sensitive_optimizer_type:  # from torch.optim
+            optimizer_module = torch.optim
+        else:  # from other library
+            values = case_sensitive_optimizer_type.split(".")
+            optimizer_module = importlib.import_module(".".join(values[:-1]))
+            case_sensitive_optimizer_type = values[-1]
+
+        optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
+    """
+    # wrap any of the above optimizers with schedulefree, if the optimizer is not already schedule-free
+    if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+
+        schedulefree_wrapper_kwargs = {}
+        if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
+            for arg in args.schedulefree_wrapper_args:
+                key, value = arg.split("=")
+                value = ast.literal_eval(value)
+                schedulefree_wrapper_kwargs[key] = value
+
+        sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
+        sf_wrapper.train()  # set the wrapper to train mode
+
+        # the optimizer must be a subclass of torch.optim.Optimizer, so we put a proxy class over the wrapper
+        class OptimizerProxy(torch.optim.Optimizer):
+            def __init__(self, sf_wrapper):
+                self._sf_wrapper = sf_wrapper
+
+            def __getattr__(self, name):
+                return getattr(self._sf_wrapper, name)
+
+            # override properties
+            @property
+            def state(self):
+                return self._sf_wrapper.state
+
+            @state.setter
+            def state(self, state):
+                self._sf_wrapper.state = state
+
+            @property
+            def param_groups(self):
+                return self._sf_wrapper.param_groups
+
+            @param_groups.setter
+            def param_groups(self, param_groups):
+                self._sf_wrapper.param_groups = param_groups
+
+            @property
+            def defaults(self):
+                return self._sf_wrapper.defaults
+
+            @defaults.setter
+            def defaults(self, defaults):
+                self._sf_wrapper.defaults = defaults
+
+            def add_param_group(self, param_group):
+                self._sf_wrapper.add_param_group(param_group)
+
+            def load_state_dict(self, state_dict):
+                self._sf_wrapper.load_state_dict(state_dict)
+
+            def state_dict(self):
+                return self._sf_wrapper.state_dict()
+
+            def zero_grad(self):
+                self._sf_wrapper.zero_grad()
+
+            def step(self, closure=None):
+                self._sf_wrapper.step(closure)
+
+            def train(self):
+                self._sf_wrapper.train()
+
+            def eval(self):
+                self._sf_wrapper.eval()
+
+            # method to pass isinstance checks
+            def __instancecheck__(self, instance):
+                return isinstance(instance, (type(self), Optimizer))
+
+        optimizer = OptimizerProxy(sf_wrapper)
+
+        logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}")
+    """
+
     optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
     optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
 
     return optimizer_name, optimizer_args, optimizer
+
+
+def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
+    return args.optimizer_type.lower().endswith("schedulefree".lower())  # or args.optimizer_schedulefree_wrapper
+
+
+def get_dummy_scheduler(optimizer: Optimizer) -> Any:
+    # dummy scheduler for schedule-free optimizers: it supports only a no-op step() and get_last_lr()
+    # this scheduler is used for logging only
+    # it isn't wrapped by accelerator because it is not a subclass of torch.optim.lr_scheduler._LRScheduler
+    class DummyScheduler:
+        def __init__(self, optimizer: Optimizer):
+            self.optimizer = optimizer
+
+        def step(self):
+            pass
+
+        def get_last_lr(self):
+            return [group["lr"] for group in self.optimizer.param_groups]
+
+    return DummyScheduler(optimizer)
+
+
 # Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
 # Add some checking and features to the original function.
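The `if optimizer is None:` fallback above resolves optimizer classes by name, either as a bare attribute of `torch.optim` or via a dotted import path into another library. A standalone sketch of that importlib pattern (the helper name `resolve_optimizer_class` is hypothetical):

```python
import importlib

import torch

def resolve_optimizer_class(optimizer_type: str):
    """Resolve 'AdamW' from torch.optim, or 'some.module.Class' from any library."""
    if "." not in optimizer_type:  # bare name -> look up in torch.optim
        return getattr(torch.optim, optimizer_type)
    module_name, _, class_name = optimizer_type.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

print(resolve_optimizer_class("AdamW"))            # -> torch.optim AdamW class
print(resolve_optimizer_class("torch.optim.SGD"))  # -> torch.optim SGD class
```

Renaming the variable to `case_sensitive_optimizer_type` matters here: `getattr` and `import_module` need the user's original casing, while the schedule-free branch matches on the lowercased name.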
train_util.py:

@@ -4610,6 +4744,10 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     """
     Unified API to get any scheduler from its name.
     """
+    # if a schedule-free optimizer is used, return the dummy scheduler
+    if is_schedulefree_optimizer(optimizer, args):
+        return get_dummy_scheduler(optimizer)
+
     name = args.lr_scheduler
     num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
     num_warmup_steps: Optional[int] = (
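This early return keeps the rest of the training loop unchanged: `step()` becomes a harmless no-op and `get_last_lr()` still feeds the loggers from the optimizer's param groups. An illustrative standalone version of that contract (a re-implementation for demonstration, not the train_util code):

```python
import torch

class DummyScheduler:  # same shape as the class added in train_util.py above
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):  # the training loop may call step() freely; nothing happens
        pass

    def get_last_lr(self):  # logging reads the lr stored on the param groups
        return [group["lr"] for group in self.optimizer.param_groups]

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = DummyScheduler(opt)
sched.step()
print(sched.get_last_lr())  # [0.1] -> LR logging still works without a schedule
```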
requirements.txt:

@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
 bitsandbytes==0.43.3
 prodigyopt==1.0
 lion-pytorch==0.0.6
+schedulefree==1.2.7
 tensorboard
 safetensors==0.4.4
 # gradio==3.16.2
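As a quick sanity check after updating the environment, the new pin can be verified from Python (illustrative; assumes a standard `pip install -r requirements.txt` was run):

```python
# Confirm the newly pinned dependency resolves in the current environment.
import importlib.metadata

print(importlib.metadata.version("schedulefree"))  # expect "1.2.7" per the pin
```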