Merge branch 'dev' into sd3

Kohya S
2024-09-26 21:41:25 +09:00
3 changed files with 15 additions and 5 deletions

View File

@@ -707,9 +707,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
### Working in progress
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
- transformers, accelerate and huggingface_hub are updated.
- bitsandbytes, transformers, accelerate and huggingface_hub are updated.
- If you encounter any issues, please report them.
- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds!
- There is no short form for these optimizer names, so specify the full class path, e.g. `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (note: `bitsandbytes`, not `bnb`); see the sketch after this list.
- Fixed a bug in the latents cache: when `flip_aug`, `alpha_mask`, and `random_crop` differ across multiple subsets in the dataset configuration file (.toml), the settings of the last subset were applied to all of them instead of each subset's own settings being used.
- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
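
For readers unfamiliar with the full-class-path form of `--optimizer_type`, the general idea is that the string is split into a module path and a class name and imported dynamically. The snippet below is a minimal sketch of that idea, not sd-scripts' actual implementation; `resolve_optimizer_class` and the example learning rate are illustrative.

```python
# Minimal sketch: resolve an optimizer given as a full class path,
# e.g. "bitsandbytes.optim.AdEMAMix8bit" (illustrative, not the script's code).
import importlib

def resolve_optimizer_class(optimizer_type: str):
    # "package.module.ClassName" -> import "package.module", then fetch "ClassName"
    module_name, class_name = optimizer_type.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# Usage (assuming bitsandbytes 0.44.0 or later is installed):
# optimizer_class = resolve_optimizer_class("bitsandbytes.optim.AdEMAMix8bit")
# optimizer = optimizer_class(trainable_params, lr=1e-4)
```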

View File

@@ -3312,7 +3312,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
"--optimizer_type",
type=str,
default="",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, "
"Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, "
"DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, "
"AdaFactor. "
"Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.",
)
# backward compatibility
@@ -4364,7 +4368,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
optimizer_type = args.optimizer_type
if args.use_8bit_adam:
@@ -4417,6 +4421,7 @@ def get_optimizer(args, trainable_params):
lr = args.learning_rate
optimizer = None
optimizer_class = None
if optimizer_type == "Lion".lower():
try:
@@ -4474,7 +4479,8 @@ def get_optimizer(args, trainable_params):
"No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
if optimizer_class is not None:
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "PagedAdamW".lower():
logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}")
@@ -4749,6 +4755,7 @@ def get_optimizer(args, trainable_params):
logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}")
"""
# for logging
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
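
For orientation, the hunks above follow a common dispatch pattern: each branch selects an optimizer class (and possibly extra kwargs), and the shared instantiation is now guarded so that branches which construct the optimizer directly, leaving `optimizer_class` as `None`, are not clobbered. A condensed sketch of that assumed structure, not a verbatim copy of `get_optimizer`:

```python
# Condensed sketch of the dispatch pattern (assumed structure, illustrative only).
def build_optimizer(optimizer_type, trainable_params, lr, optimizer_kwargs):
    optimizer_type = optimizer_type.lower()  # branches compare against lowercased names
    optimizer = None
    optimizer_class = None

    if optimizer_type == "Lion8bit".lower():
        import bitsandbytes as bnb
        optimizer_class = bnb.optim.Lion8bit
    # ... further branches; some create `optimizer` directly and leave
    # `optimizer_class` as None ...

    if optimizer_class is not None:
        # only instantiate here when a branch selected a class
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    # for logging
    optimizer_name = optimizer.__class__.__module__ + "." + optimizer.__class__.__name__
    optimizer_args = ",".join(f"{k}={v}" for k, v in optimizer_kwargs.items())
    return optimizer, optimizer_name, optimizer_args
```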

View File

@@ -6,7 +6,7 @@ ftfy==6.1.1
opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.43.3
bitsandbytes==0.44.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.2.7
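
Given the pin bump to bitsandbytes 0.44.0, a quick sanity check after reinstalling dependencies is to confirm the installed version and that the new AdEMAMix classes are importable. A minimal, illustrative check:

```python
# Verify the bitsandbytes upgrade took effect (illustrative check).
import bitsandbytes as bnb

print(bnb.__version__)                          # expect 0.44.0 or newer
print(hasattr(bnb.optim, "AdEMAMix8bit"))       # True on 0.44.0+
print(hasattr(bnb.optim, "PagedAdEMAMix8bit"))  # True on 0.44.0+
```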