feat: support wandb logging

This commit is contained in:
Plat
2023-04-20 01:41:12 +09:00
parent 334589af4e
commit 27ffd9fe3d
6 changed files with 33 additions and 7 deletions

View File

@@ -538,7 +538,7 @@ def train(args):
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("network_train")
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
loss_list = []
loss_total = 0.0