add lion optimizer support

This commit is contained in:
Kohya S
2023-02-19 15:26:14 +09:00
parent a76ad2d1d5
commit 048e7cd428
5 changed files with 34 additions and 2 deletions

View File

@@ -158,6 +158,13 @@ def train(args):
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer") print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit optimizer_class = bnb.optim.AdamW8bit
elif args.use_lion_optimizer:
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print("use Lion optimizer")
optimizer_class = lion_pytorch.Lion
else: else:
optimizer_class = torch.optim.AdamW optimizer_class = torch.optim.AdamW

View File

@@ -1389,6 +1389,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可") help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可")
parser.add_argument("--use_8bit_adam", action="store_true", parser.add_argument("--use_8bit_adam", action="store_true",
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要") help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
parser.add_argument("--use_lion_optimizer", action="store_true",
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う lion-pytorch のインストールが必要)")
parser.add_argument("--mem_eff_attn", action="store_true", parser.add_argument("--mem_eff_attn", action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
parser.add_argument("--xformers", action="store_true", parser.add_argument("--xformers", action="store_true",
@@ -1424,8 +1426,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--noise_offset", type=float, default=None, parser.add_argument("--noise_offset", type=float, default=None,
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)") help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)")
parser.add_argument("--lowram", action="store_true", parser.add_argument("--lowram", action="store_true",
help="load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle)") help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなどColabやKaggleなどRAMに比べてVRAMが多い環境向け")
if support_dreambooth: if support_dreambooth:
# DreamBooth training # DreamBooth training
parser.add_argument("--prior_loss_weight", type=float, default=1.0, parser.add_argument("--prior_loss_weight", type=float, default=1.0,

View File

@@ -124,6 +124,13 @@ def train(args):
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer") print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit optimizer_class = bnb.optim.AdamW8bit
elif args.use_lion_optimizer:
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print("use Lion optimizer")
optimizer_class = lion_pytorch.Lion
else: else:
optimizer_class = torch.optim.AdamW optimizer_class = torch.optim.AdamW

View File

@@ -156,10 +156,12 @@ def train(args):
# モデルを読み込む # モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# work on low-ram device # work on low-ram device
if args.lowram: if args.lowram:
text_encoder.to("cuda") text_encoder.to("cuda")
unet.to("cuda") unet.to("cuda")
# モデルに xformers とか memory efficient attention を組み込む # モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -214,6 +216,13 @@ def train(args):
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer") print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit optimizer_class = bnb.optim.AdamW8bit
elif args.use_lion_optimizer:
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print("use Lion optimizer")
optimizer_class = lion_pytorch.Lion
else: else:
optimizer_class = torch.optim.AdamW optimizer_class = torch.optim.AdamW

View File

@@ -207,6 +207,13 @@ def train(args):
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer") print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit optimizer_class = bnb.optim.AdamW8bit
elif args.use_lion_optimizer:
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print("use Lion optimizer")
optimizer_class = lion_pytorch.Lion
else: else:
optimizer_class = torch.optim.AdamW optimizer_class = torch.optim.AdamW