Merge pull request #1277 from Cauldrath/negative_learning

Allow negative learning rate
This commit is contained in:
Kohya S
2024-05-19 18:55:50 +09:00
committed by GitHub

View File

@@ -272,7 +272,7 @@ def train(args):
# 学習を準備する:モデルを適切な状態にする
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
train_unet = args.learning_rate > 0
train_unet = args.learning_rate != 0
train_text_encoder1 = False
train_text_encoder2 = False
@@ -284,8 +284,8 @@ def train(args):
text_encoder2.gradient_checkpointing_enable()
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
train_text_encoder2 = lr_te2 > 0
train_text_encoder1 = lr_te1 != 0
train_text_encoder2 = lr_te2 != 0
# caching one text encoder output is not supported
if not train_text_encoder1: