fix: tensorboard not working

This commit is contained in:
Plat
2023-04-20 05:33:32 +09:00
parent 12567f55cd
commit a69b24a069

View File

@@ -2745,15 +2745,19 @@ def prepare_accelerator(args: argparse.Namespace):
log_prefix = "" if args.log_prefix is None else args.log_prefix log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
if args.log_with is not None:
log_with = "tensorboard" if args.log_with is None else args.log_with log_with = "tensorboard" if args.log_with is None else args.log_with
if log_with in ["tensorboard", "all"]:
if logging_dir is None:
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
if log_with in ["wandb", "all"]: if log_with in ["wandb", "all"]:
if logging_dir is not None:
os.makedirs(logging_dir)
os.environ["WANDB_DIR"] = logging_dir
try: try:
import wandb import wandb
except ImportError: except ImportError:
raise ImportError("No wandb / wandb がインストールされていないようです") raise ImportError("No wandb / wandb がインストールされていないようです")
if logging_dir is not None:
os.makedirs(logging_dir)
os.environ["WANDB_DIR"] = logging_dir
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,