mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
refactor optimizer selection for bnb
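In short: the bitsandbytes 8-bit optimizers (AdamW8bit, SGDNesterov8bit, Lion8bit, PagedAdamW8bit, PagedLion8bit) previously each carried their own try/except import of bitsandbytes; this commit funnels them through a single `elif optimizer_type.endswith("8bit")` branch with one shared import. A minimal standalone sketch of the resulting shape (the helper name and lookup table are illustrative, not code from the repo, which uses an if/elif chain):

    # Illustrative sketch only; sd-scripts dispatches with an if/elif chain.
    _BNB_CLASS_NAMES = {
        "adamw8bit": "AdamW8bit",
        "sgdnesterov8bit": "SGD8bit",        # instantiated with nesterov=True
        "lion8bit": "Lion8bit",              # needs bitsandbytes >= 0.38.0
        "pagedadamw8bit": "PagedAdamW8bit",  # needs bitsandbytes >= 0.39.0
        "pagedlion8bit": "PagedLion8bit",    # needs bitsandbytes >= 0.39.0
    }

    def resolve_bnb_optimizer_class(optimizer_type: str):
        try:
            import bitsandbytes as bnb  # one shared import for every *8bit type
        except ImportError:
            raise ImportError("bitsandbytes is not installed")
        # Attribute access fails with AttributeError on bitsandbytes builds that
        # predate the requested class, mirroring the version checks in the patch.
        return getattr(bnb.optim, _BNB_CLASS_NAMES[optimizer_type.lower()])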
@@ -65,6 +65,7 @@ import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import library.model_util as model_util
 import library.huggingface_util as huggingface_util

 # from library.attention_processors import FlashAttnProcessor
 # from library.hypernetwork import replace_attentions_for_hypernetwork
 from library.original_unet import UNet2DConditionModel
@@ -2165,7 +2166,7 @@ def cache_batch_latents(
             info.latents = latent
             if flip_aug:
                 info.latents_flipped = flipped_latent

-
+    if torch.cuda.is_available():
         torch.cuda.empty_cache()

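This hunk gates the cache cleanup on CUDA availability: `torch.cuda.empty_cache()` assumes a CUDA build, and on CPU-only installs calls into `torch.cuda` can raise instead of no-opping. A standalone illustration (the helper name is hypothetical):

    import torch

    def maybe_empty_cuda_cache():
        # empty_cache() releases cached GPU memory back to the driver;
        # guard it so CPU-only setups skip the call instead of erroring.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    maybe_empty_cuda_cache()  # safe no-op without a CUDA device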
@@ -3184,32 +3185,9 @@ def get_optimizer(args, trainable_params):
     # print("optkwargs:", optimizer_kwargs)

     lr = args.learning_rate
     optimizer = None

-    if optimizer_type == "AdamW8bit".lower():
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
-        print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
-        optimizer_class = bnb.optim.AdamW8bit
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
-
-    elif optimizer_type == "SGDNesterov8bit".lower():
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
-        print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
-        if "momentum" not in optimizer_kwargs:
-            print(
-                f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
-            )
-            optimizer_kwargs["momentum"] = 0.9
-
-        optimizer_class = bnb.optim.SGD8bit
-        optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
-
-    elif optimizer_type == "Lion".lower():
+    if optimizer_type == "Lion".lower():
         try:
             import lion_pytorch
         except ImportError:
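The forced `momentum=0.9` default is not cosmetic: PyTorch's SGD (and, per the message in the patch, bitsandbytes' 8-bit variant) rejects `nesterov=True` without a positive momentum. With plain PyTorch:

    import torch

    params = [torch.nn.Parameter(torch.zeros(3))]

    # Nesterov momentum requires momentum > 0 and zero dampening;
    # without it, SGD raises ValueError at construction time:
    try:
        torch.optim.SGD(params, lr=1e-3, nesterov=True)
    except ValueError as e:
        print(e)  # "Nesterov momentum requires a momentum and zero dampening"

    opt = torch.optim.SGD(params, lr=1e-3, nesterov=True, momentum=0.9)  # OK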
@@ -3217,37 +3195,53 @@ def get_optimizer(args, trainable_params):
         print(f"use Lion optimizer | {optimizer_kwargs}")
         optimizer_class = lion_pytorch.Lion
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

     elif optimizer_type.endswith("8bit".lower()):
         try:
             import bitsandbytes as bnb
         except ImportError:
             raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")

-        if optimizer_type == "Lion8bit".lower():
+        if optimizer_type == "AdamW8bit".lower():
+            print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
+            optimizer_class = bnb.optim.AdamW8bit
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+        elif optimizer_type == "SGDNesterov8bit".lower():
+            print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
+            if "momentum" not in optimizer_kwargs:
+                print(
+                    f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
+                )
+                optimizer_kwargs["momentum"] = 0.9
+
+            optimizer_class = bnb.optim.SGD8bit
+            optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
+
+        elif optimizer_type == "Lion8bit".lower():
             print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
             try:
                 optimizer_class = bnb.optim.Lion8bit
             except AttributeError:
                 raise AttributeError(
                     "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください"
                 )
         elif optimizer_type == "PagedAdamW8bit".lower():
             print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}")
             try:
                 optimizer_class = bnb.optim.PagedAdamW8bit
             except AttributeError:
                 raise AttributeError(
                     "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
                 )
         elif optimizer_type == "PagedLion8bit".lower():
             print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}")
             try:
                 optimizer_class = bnb.optim.PagedLion8bit
             except AttributeError:
                 raise AttributeError(
                     "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
                 )

         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

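The try/except AttributeError blocks here are runtime version probes: old bitsandbytes releases simply lack the newer classes, so attribute access fails before any training starts. The same check could be written as a small helper (hypothetical, for illustration only):

    def require_bnb_optimizer(bnb_optim_module, class_name: str, min_version: str):
        # Hypothetical helper mirroring the patch's AttributeError guard.
        optimizer_class = getattr(bnb_optim_module, class_name, None)
        if optimizer_class is None:
            raise AttributeError(
                f"No {class_name}. The installed bitsandbytes seems to be too old; "
                f"please install {min_version} or later."
            )
        return optimizer_class

    # e.g. require_bnb_optimizer(bnb.optim, "PagedAdamW8bit", "0.39.0")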
@@ -3379,7 +3373,7 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

-    else:
+    if optimizer is None:
         # 任意のoptimizerを使う
         optimizer_type = args.optimizer_type  # lowerでないやつ(微妙)
         print(f"use {optimizer_type} | {optimizer_kwargs}")
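Changing the trailing `else:` to `if optimizer is None:` decouples the fallback from the elif chain: any branch that only picked an `optimizer_class` without instantiating it (or no branch at all) now reaches one shared construction path, while branches that already built `optimizer` skip it. The Japanese comments in this hunk read "use an arbitrary optimizer" and "the non-lowercased value (a bit awkward)". In sd-scripts the fallback resolves `--optimizer_type` dynamically; a sketch of that pattern under the same conventions (dotted names refer to importable modules, bare names to `torch.optim`):

    import importlib

    import torch

    def build_fallback_optimizer(optimizer_type, trainable_params, lr, **optimizer_kwargs):
        # "AdamW"          -> torch.optim.AdamW
        # "my_pkg.MyOptim" -> getattr(importlib.import_module("my_pkg"), "MyOptim")
        if "." not in optimizer_type:
            optimizer_module = torch.optim
        else:
            module_name, _, optimizer_type = optimizer_type.rpartition(".")
            optimizer_module = importlib.import_module(module_name)
        optimizer_class = getattr(optimizer_module, optimizer_type)
        return optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    # e.g. build_fallback_optimizer("AdamW", model.parameters(), 1e-4, weight_decay=0.01)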