diff --git a/library/train_util.py b/library/train_util.py index 8e91de05..cfb5b7ee 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2075,7 +2075,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", ) parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前") + parser.add_argument( + "--log_tracker_name", type=str, default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名" + ) parser.add_argument( "--noise_offset", type=float, @@ -2746,7 +2748,7 @@ def prepare_accelerator(args: argparse.Namespace): logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) if args.log_with is not None: - log_with = "tensorboard" if args.log_with is None else args.log_with + log_with = args.log_with if log_with in ["tensorboard", "all"]: if logging_dir is None: raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください") @@ -2756,7 +2758,7 @@ def prepare_accelerator(args: argparse.Namespace): except ImportError: raise ImportError("No wandb / wandb がインストールされていないようです") if logging_dir is not None: - os.makedirs(logging_dir) + os.makedirs(logging_dir, exist_ok=True) os.environ["WANDB_DIR"] = logging_dir accelerator = Accelerator( @@ -3222,14 +3224,12 @@ def sample_images( wandb_tracker = accelerator.get_tracker("wandb") try: import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず + except ImportError: # 事前に一度確認するのでここはエラー出ないはず raise ImportError("No wandb / wandb がインストールされていないようです") - + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 + except: # wandb 無効時 pass - - # clear pipeline and cache to reduce vram usage del pipeline @@ -3239,6 +3239,7 @@ def sample_images( torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) + # endregion # region 前処理用