Mirror of https://github.com/kohya-ss/sd-scripts.git
Commit: refactor optimizer selection for bnb
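Summary of the change, as read from the hunks below: get_optimizer now initializes an optimizer = None sentinel before the type dispatch; the standalone AdamW8bit and SGDNesterov8bit branches, each of which repeated its own try/except import of bitsandbytes, are folded into the shared 8-bit branch alongside Lion8bit so that bitsandbytes is imported exactly once; and the closing else: fallback becomes if optimizer is None:, so a branch that matches without constructing an optimizer still falls through to the generic path.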
@@ -65,6 +65,7 @@ import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import library.model_util as model_util
 import library.huggingface_util as huggingface_util
+
 # from library.attention_processors import FlashAttnProcessor
 # from library.hypernetwork import replace_attentions_for_hypernetwork
 from library.original_unet import UNet2DConditionModel
@@ -3184,32 +3185,9 @@ def get_optimizer(args, trainable_params):
     # print("optkwargs:", optimizer_kwargs)
 
     lr = args.learning_rate
+    optimizer = None
 
-    if optimizer_type == "AdamW8bit".lower():
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-        print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
-        optimizer_class = bnb.optim.AdamW8bit
-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
-
-    elif optimizer_type == "SGDNesterov8bit".lower():
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
-        print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
-        if "momentum" not in optimizer_kwargs:
-            print(
-                f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
-            )
-            optimizer_kwargs["momentum"] = 0.9
-
-        optimizer_class = bnb.optim.SGD8bit
-        optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
-
-    elif optimizer_type == "Lion".lower():
+    if optimizer_type == "Lion".lower():
         try:
             import lion_pytorch
         except ImportError:
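To make the new shape concrete, here is a minimal self-contained sketch of the sentinel dispatch this hunk introduces; build_optimizer and its arguments are hypothetical stand-ins, not sd-scripts' API, and the branches are reduced to stock torch.optim classes:

import torch

def build_optimizer(name, params, lr, **kwargs):
    # Hypothetical reduction of get_optimizer: start from a sentinel and
    # let each branch assign it only when the type matches.
    optimizer = None

    if name.lower() == "adamw":
        optimizer = torch.optim.AdamW(params, lr=lr, **kwargs)
    elif name.lower() == "sgdnesterov":
        kwargs.setdefault("momentum", 0.9)  # Nesterov SGD requires momentum
        optimizer = torch.optim.SGD(params, lr=lr, nesterov=True, **kwargs)

    if optimizer is None:
        # generic fallback: resolve any optimizer class by its torch.optim name
        optimizer = getattr(torch.optim, name)(params, lr=lr, **kwargs)
    return optimizer

params = [torch.nn.Parameter(torch.zeros(2))]  # throwaway parameter for the demo
print(type(build_optimizer("AdamW", params, 1e-4)).__name__)    # AdamW, via the branch
print(type(build_optimizer("Adagrad", params, 1e-4)).__name__)  # Adagrad, via the fallback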
@@ -3224,7 +3202,23 @@ def get_optimizer(args, trainable_params):
         except ImportError:
             raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
 
-        if optimizer_type == "Lion8bit".lower():
+        if optimizer_type == "AdamW8bit".lower():
+            print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
+            optimizer_class = bnb.optim.AdamW8bit
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+        elif optimizer_type == "SGDNesterov8bit".lower():
+            print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
+            if "momentum" not in optimizer_kwargs:
+                print(
+                    f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
+                )
+                optimizer_kwargs["momentum"] = 0.9
+
+            optimizer_class = bnb.optim.SGD8bit
+            optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
+
+        elif optimizer_type == "Lion8bit".lower():
             print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
             try:
                 optimizer_class = bnb.optim.Lion8bit
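The point of this hunk is that all three 8-bit types now share a single bitsandbytes import guard instead of repeating try/except ImportError per branch. A hedged sketch of that shape (pick_8bit_optimizer_class is an illustrative name, not sd-scripts' API; bnb.optim.AdamW8bit, SGD8bit, and Lion8bit are real bitsandbytes classes, though Lion8bit only exists in newer releases, which is why the diff above also guards it with a try):

def pick_8bit_optimizer_class(optimizer_type):
    # One import guard covers every 8-bit variant, mirroring the refactor.
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")

    if optimizer_type == "AdamW8bit".lower():
        return bnb.optim.AdamW8bit
    elif optimizer_type == "SGDNesterov8bit".lower():
        return bnb.optim.SGD8bit
    elif optimizer_type == "Lion8bit".lower():
        try:
            return bnb.optim.Lion8bit  # missing in older bitsandbytes releases
        except AttributeError:
            raise AttributeError("Lion8bit requires a newer bitsandbytes release")
    raise ValueError(f"unknown 8-bit optimizer type: {optimizer_type}")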
@@ -3379,7 +3373,7 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
-    else:
+    if optimizer is None:
         # 任意のoptimizerを使う
         optimizer_type = args.optimizer_type  # lowerでないやつ(微妙)
         print(f"use {optimizer_type} | {optimizer_kwargs}")
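This last hunk is the behavioral key: once several optimizer types share one umbrella branch, that branch can match without constructing anything, and a plain else: at the end of the chain would then skip the generic fallback. Testing the sentinel instead preserves the fallback for exactly those cases. A tiny illustration with hypothetical names:

def choose(kind):
    result = None

    if kind.endswith("8bit"):  # umbrella branch, like the shared bnb branch
        if kind == "adamw8bit":
            result = "AdamW8bit"
        # any other *8bit name matches the umbrella but assigns nothing

    # an `else:` here would never run for "foo8bit"; the sentinel check does
    if result is None:
        result = f"generic optimizer for {kind!r}"
    return result

print(choose("adamw8bit"))  # AdamW8bit
print(choose("foo8bit"))    # generic optimizer for 'foo8bit'
print(choose("sgd"))        # generic optimizer for 'sgd'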