apply dadaptation

commit 045a3dbe48
parent 08ae46b163
Author: unknown
Date: 2023-02-19 18:37:07 +09:00
3 changed files with 27 additions and 0 deletions


@@ -37,6 +37,9 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
         logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
         logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be the same as textencoder
+    if args.use_dadaptation_optimizer:  # tracking the d*lr value of the unet
+        logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d'] * lr_scheduler.optimizers[-1].param_groups[0]['lr']
+
     return logs
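
Note on the hunk above: dadaptation's DAdaptAdam keeps its adapted distance estimate under the key 'd' in each param group, so d * lr is the effective step size being logged. A minimal sketch of reading that value directly from an optimizer, assuming a hypothetical stand-in model (net below is illustrative, not from this repository):

import torch
import dadaptation

net = torch.nn.Linear(4, 4)  # hypothetical stand-in model
optimizer = dadaptation.DAdaptAdam(net.parameters(), lr=1.0)

# 'd' starts at the small initial d0 and grows as optimizer.step() adapts it;
# the quantity logged as "lr/d*lr" above is exactly this product.
group = optimizer.param_groups[0]
print("lr/d*lr =", group["d"] * group["lr"])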
@@ -223,6 +226,18 @@ def train(args):
             raise ImportError("No lion_pytorch / lion_pytorch seems not to be installed")
         print("use Lion optimizer")
         optimizer_class = lion_pytorch.Lion
+    elif args.use_dadaptation_optimizer:
+        try:
+            import dadaptation
+        except ImportError:
+            raise ImportError("No dadaptation / dadaptation seems not to be installed")
+        print("use dadaptation optimizer")
+        optimizer_class = dadaptation.DAdaptAdam
+        if args.network_dim > args.network_alpha:
+            print('network dimension is greater than network alpha. This may make the network blow up.')
+        if args.learning_rate <= 0.1 or args.text_encoder_lr <= 0.1 or args.unet_lr <= 0.1:
+            print('learning rate is too low. When using dadaptation, set the learning rate around 1.0.')
+            print('recommended options: lr=1.0, unet_lr=1.0, text_encoder_lr=0.5')
     else:
         optimizer_class = torch.optim.AdamW
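
On the low-learning-rate warning above: D-Adaptation estimates the step size itself, and the lr passed to DAdaptAdam only scales that estimate, so 1.0 is the neutral setting. A minimal sketch of the recommended per-module values, assuming hypothetical stand-in modules (text_encoder and unet below are not this repository's objects):

import torch
import dadaptation

text_encoder = torch.nn.Linear(8, 8)  # hypothetical stand-in
unet = torch.nn.Linear(8, 8)          # hypothetical stand-in

# Per-group lr values act as multipliers on the adapted step size,
# matching the recommendation lr=1.0, unet_lr=1.0, text_encoder_lr=0.5.
optimizer = dadaptation.DAdaptAdam(
    [
        {"params": text_encoder.parameters(), "lr": 0.5},
        {"params": unet.parameters(), "lr": 1.0},
    ],
    lr=1.0,
)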