mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
refactor optimizer selection for bnb
This commit is contained in:
@@ -65,6 +65,7 @@ import safetensors.torch
|
|||||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import library.huggingface_util as huggingface_util
|
import library.huggingface_util as huggingface_util
|
||||||
|
|
||||||
# from library.attention_processors import FlashAttnProcessor
|
# from library.attention_processors import FlashAttnProcessor
|
||||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||||
from library.original_unet import UNet2DConditionModel
|
from library.original_unet import UNet2DConditionModel
|
||||||
@@ -2165,7 +2166,7 @@ def cache_batch_latents(
|
|||||||
info.latents = latent
|
info.latents = latent
|
||||||
if flip_aug:
|
if flip_aug:
|
||||||
info.latents_flipped = flipped_latent
|
info.latents_flipped = flipped_latent
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@@ -3184,32 +3185,9 @@ def get_optimizer(args, trainable_params):
|
|||||||
# print("optkwargs:", optimizer_kwargs)
|
# print("optkwargs:", optimizer_kwargs)
|
||||||
|
|
||||||
lr = args.learning_rate
|
lr = args.learning_rate
|
||||||
|
optimizer = None
|
||||||
|
|
||||||
if optimizer_type == "AdamW8bit".lower():
|
if optimizer_type == "Lion".lower():
|
||||||
try:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
|
||||||
print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
|
|
||||||
optimizer_class = bnb.optim.AdamW8bit
|
|
||||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
|
||||||
|
|
||||||
elif optimizer_type == "SGDNesterov8bit".lower():
|
|
||||||
try:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
|
||||||
print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
|
|
||||||
if "momentum" not in optimizer_kwargs:
|
|
||||||
print(
|
|
||||||
f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
|
|
||||||
)
|
|
||||||
optimizer_kwargs["momentum"] = 0.9
|
|
||||||
|
|
||||||
optimizer_class = bnb.optim.SGD8bit
|
|
||||||
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
|
||||||
|
|
||||||
elif optimizer_type == "Lion".lower():
|
|
||||||
try:
|
try:
|
||||||
import lion_pytorch
|
import lion_pytorch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -3217,37 +3195,53 @@ def get_optimizer(args, trainable_params):
|
|||||||
print(f"use Lion optimizer | {optimizer_kwargs}")
|
print(f"use Lion optimizer | {optimizer_kwargs}")
|
||||||
optimizer_class = lion_pytorch.Lion
|
optimizer_class = lion_pytorch.Lion
|
||||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||||
|
|
||||||
elif optimizer_type.endswith("8bit".lower()):
|
elif optimizer_type.endswith("8bit".lower()):
|
||||||
try:
|
try:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
|
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
|
||||||
|
|
||||||
if optimizer_type == "Lion8bit".lower():
|
if optimizer_type == "AdamW8bit".lower():
|
||||||
print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
|
print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
|
||||||
try:
|
optimizer_class = bnb.optim.AdamW8bit
|
||||||
optimizer_class = bnb.optim.Lion8bit
|
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||||
except AttributeError:
|
|
||||||
raise AttributeError(
|
elif optimizer_type == "SGDNesterov8bit".lower():
|
||||||
"No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください"
|
print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
|
||||||
)
|
if "momentum" not in optimizer_kwargs:
|
||||||
|
print(
|
||||||
|
f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
|
||||||
|
)
|
||||||
|
optimizer_kwargs["momentum"] = 0.9
|
||||||
|
|
||||||
|
optimizer_class = bnb.optim.SGD8bit
|
||||||
|
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
||||||
|
|
||||||
|
elif optimizer_type == "Lion8bit".lower():
|
||||||
|
print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
|
||||||
|
try:
|
||||||
|
optimizer_class = bnb.optim.Lion8bit
|
||||||
|
except AttributeError:
|
||||||
|
raise AttributeError(
|
||||||
|
"No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください"
|
||||||
|
)
|
||||||
elif optimizer_type == "PagedAdamW8bit".lower():
|
elif optimizer_type == "PagedAdamW8bit".lower():
|
||||||
print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}")
|
print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}")
|
||||||
try:
|
try:
|
||||||
optimizer_class = bnb.optim.PagedAdamW8bit
|
optimizer_class = bnb.optim.PagedAdamW8bit
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
"No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
"No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
||||||
)
|
)
|
||||||
elif optimizer_type == "PagedLion8bit".lower():
|
elif optimizer_type == "PagedLion8bit".lower():
|
||||||
print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}")
|
print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}")
|
||||||
try:
|
try:
|
||||||
optimizer_class = bnb.optim.PagedLion8bit
|
optimizer_class = bnb.optim.PagedLion8bit
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
"No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
"No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||||
|
|
||||||
@@ -3379,7 +3373,7 @@ def get_optimizer(args, trainable_params):
|
|||||||
optimizer_class = torch.optim.AdamW
|
optimizer_class = torch.optim.AdamW
|
||||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||||
|
|
||||||
else:
|
if optimizer is None:
|
||||||
# 任意のoptimizerを使う
|
# 任意のoptimizerを使う
|
||||||
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
||||||
print(f"use {optimizer_type} | {optimizer_kwargs}")
|
print(f"use {optimizer_type} | {optimizer_kwargs}")
|
||||||
|
|||||||
Reference in New Issue
Block a user