mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge pull request #209 from BootsofLagrangian/dadaptation
Dadaptation optimizer
This commit is contained in:

fine_tune.py (10 lines changed)
@@ -165,6 +165,16 @@ def train(args):
       raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
     print("use Lion optimizer")
     optimizer_class = lion_pytorch.Lion
+  elif args.use_dadaptation_optimizer:
+    try:
+      import dadaptation
+    except ImportError:
+      raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
+    print("use dadaptation optimizer")
+    optimizer_class = dadaptation.DAdaptAdam
+    if args.learning_rate <= 0.1:
+      print('learning rate is too low. If using dadaptation, set learning rate around 1.0.')
+      print('recommend option: lr=1.0')
   else:
     optimizer_class = torch.optim.AdamW

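For context, a minimal sketch (not part of this diff) of what the selected optimizer_class produces; the `model` below is a stand-in for the trained network. D-Adaptation estimates the step size d online and uses lr only as a multiplier on that estimate, which is why the hunk above warns when learning_rate is far below 1.0:

  import torch
  import dadaptation

  model = torch.nn.Linear(4, 4)  # placeholder for the trained U-Net / text encoder
  # lr=1.0 is the recommended setting: the effective step size is d * lr,
  # with d adapted automatically during training.
  optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1.0, weight_decay=0.0)
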
@@ -1391,6 +1391,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
                       help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
   parser.add_argument("--use_lion_optimizer", action="store_true",
                       help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
+  parser.add_argument("--use_dadaptation_optimizer", action="store_true",
+                      help="use dadaptation optimizer (requires dadaptation) / dadaptationオプティマイザを使う( dadaptation のインストールが必要)")
   parser.add_argument("--mem_eff_attn", action="store_true",
                       help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
   parser.add_argument("--xformers", action="store_true",

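A self-contained sketch of how the new switch behaves after parsing, assuming the same argparse pattern as the surrounding flags (the learning-rate default shown here is hypothetical):

  import argparse

  parser = argparse.ArgumentParser()
  parser.add_argument("--use_dadaptation_optimizer", action="store_true",
                      help="use dadaptation optimizer (requires dadaptation)")
  parser.add_argument("--learning_rate", type=float, default=2.0e-6)  # hypothetical default

  # action="store_true" means the flag is False unless given on the command line
  args = parser.parse_args(["--use_dadaptation_optimizer", "--learning_rate", "1.0"])
  print(args.use_dadaptation_optimizer, args.learning_rate)  # True 1.0
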
@@ -37,6 +37,9 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
     logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
     logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be the same as textencoder
 
+  if args.use_dadaptation_optimizer:  # tracking d*lr value of unet.
+    logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d'] * lr_scheduler.optimizers[-1].param_groups[0]['lr']
+
   return logs

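The d*lr value logged above is the effective step size D-Adaptation is actually applying. A hedged sketch of reading it straight from a DAdaptAdam instance, assuming (as the hunk above does) that the package keeps d in each param group:

  import torch
  import dadaptation

  model = torch.nn.Linear(4, 4)  # placeholder module
  optimizer = dadaptation.DAdaptAdam(model.parameters(), lr=1.0)

  # d starts from its initial seed and is re-estimated as optimizer.step() runs,
  # so this value grows over the course of training.
  group = optimizer.param_groups[0]
  print("d*lr =", group['d'] * group['lr'])
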
@@ -223,6 +226,18 @@ def train(args):
       raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
     print("use Lion optimizer")
     optimizer_class = lion_pytorch.Lion
+  elif args.use_dadaptation_optimizer:
+    try:
+      import dadaptation
+    except ImportError:
+      raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
+    print("use dadaptation optimizer")
+    optimizer_class = dadaptation.DAdaptAdam
+    if args.network_dim > args.network_alpha:
+      print('network dimension is greater than network alpha. This may cause the network to blow up.')
+    if args.learning_rate <= 0.1 or args.text_encoder_lr <= 0.1 or args.unet_lr <= 0.1:
+      print('learning rate is too low. If using dadaptation, set learning rate around 1.0.')
+      print('recommend option: lr=1.0, unet_lr=1.0, text_encoder_lr=0.5')
   else:
     optimizer_class = torch.optim.AdamW

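Behind the network_dim/network_alpha warning: in the LoRA networks these scripts train, each delta is scaled by network_alpha / network_dim, so an alpha well below dim shrinks every update and D-Adaptation compensates by inflating d. A small sketch with hypothetical values:

  # hypothetical dim/alpha pair; the warning above fires whenever dim > alpha
  network_dim, network_alpha = 128, 1.0
  scale = network_alpha / network_dim
  print(f"LoRA scale = {scale:.4f}")  # 0.0078: far below 1.0, so d must grow to compensate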