diff --git a/library/train_util.py b/library/train_util.py index 62d82c44..b8feac74 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2740,18 +2740,21 @@ def load_tokenizer(args: argparse.Namespace): def prepare_accelerator(args: argparse.Namespace): if args.logging_dir is None: - log_with = None logging_dir = None else: - log_with = "tensorboard" if args.log_with is None else args.log_with - if log_with in ["wandb", "all"]: - try: - import wandb - except ImportError: - raise ImportError("No wandb / wandb がインストールされていないようです") log_prefix = "" if args.log_prefix is None else args.log_prefix logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) + log_with = "tensorboard" if args.log_with is None else args.log_with + if log_with in ["wandb", "all"]: + if logging_dir is not None: + os.makedirs(logging_dir) + os.environ["WANDB_DIR"] = logging_dir + try: + import wandb + except ImportError: + raise ImportError("No wandb / wandb がインストールされていないようです") + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,