From a69b24a06967284a76217ab9d2a9b351a555447c Mon Sep 17 00:00:00 2001 From: Plat Date: Thu, 20 Apr 2023 05:33:32 +0900 Subject: [PATCH] fix: tensorboard not working --- library/train_util.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b8feac74..b06caf3d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2745,15 +2745,19 @@ def prepare_accelerator(args: argparse.Namespace): log_prefix = "" if args.log_prefix is None else args.log_prefix logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) - log_with = "tensorboard" if args.log_with is None else args.log_with - if log_with in ["wandb", "all"]: - if logging_dir is not None: - os.makedirs(logging_dir) - os.environ["WANDB_DIR"] = logging_dir - try: - import wandb - except ImportError: - raise ImportError("No wandb / wandb がインストールされていないようです") + if args.log_with is not None: + log_with = "tensorboard" if args.log_with is None else args.log_with + if log_with in ["tensorboard", "all"]: + if logging_dir is None: + raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください") + if log_with in ["wandb", "all"]: + try: + import wandb + except ImportError: + raise ImportError("No wandb / wandb がインストールされていないようです") + if logging_dir is not None: + os.makedirs(logging_dir) + os.environ["WANDB_DIR"] = logging_dir accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps,