diff --git a/library/train_util.py b/library/train_util.py index 29956df3..46b76df4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -65,6 +65,7 @@ import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util + # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel @@ -2165,7 +2166,7 @@ def cache_batch_latents( info.latents = latent if flip_aug: info.latents_flipped = flipped_latent - + if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -3184,32 +3185,9 @@ def get_optimizer(args, trainable_params): # print("optkwargs:", optimizer_kwargs) lr = args.learning_rate + optimizer = None - if optimizer_type == "AdamW8bit".lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") - optimizer_class = bnb.optim.AdamW8bit - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "SGDNesterov8bit".lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") - if "momentum" not in optimizer_kwargs: - print( - f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" - ) - optimizer_kwargs["momentum"] = 0.9 - - optimizer_class = bnb.optim.SGD8bit - optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - - elif optimizer_type == "Lion".lower(): + if optimizer_type == "Lion".lower(): try: import lion_pytorch except ImportError: @@ -3217,37 +3195,53 @@ def get_optimizer(args, trainable_params): print(f"use Lion optimizer | {optimizer_kwargs}") optimizer_class = lion_pytorch.Lion optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - + elif optimizer_type.endswith("8bit".lower()): try: import bitsandbytes as bnb except ImportError: raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") - if optimizer_type == "Lion8bit".lower(): - print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.Lion8bit - except AttributeError: - raise AttributeError( - "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" - ) + if optimizer_type == "AdamW8bit".lower(): + print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov8bit".lower(): + print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print( + f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = bnb.optim.SGD8bit + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "Lion8bit".lower(): + print(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.Lion8bit + except AttributeError: + raise AttributeError( + "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" + ) elif optimizer_type == "PagedAdamW8bit".lower(): - print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedAdamW8bit - except AttributeError: - raise AttributeError( - "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) + print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedAdamW8bit + except AttributeError: + raise AttributeError( + "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) elif optimizer_type == "PagedLion8bit".lower(): - print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedLion8bit - except AttributeError: - raise AttributeError( - "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) + print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedLion8bit + except AttributeError: + raise AttributeError( + "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) @@ -3379,7 +3373,7 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - else: + if optimizer is None: # 任意のoptimizerを使う optimizer_type = args.optimizer_type # lowerでないやつ(微妙) print(f"use {optimizer_type} | {optimizer_kwargs}")