feat: support wandb logging

This commit is contained in:
Plat
2023-04-20 01:41:12 +09:00
parent 334589af4e
commit 27ffd9fe3d
6 changed files with 33 additions and 7 deletions

View File

@@ -2067,7 +2067,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
)
parser.add_argument(
"--log_with",
type=str,
default=None,
choices=["tensorboard", "wandb", "all"],
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
)
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前")
parser.add_argument(
"--noise_offset",
type=float,
@@ -2735,7 +2743,12 @@ def prepare_accelerator(args: argparse.Namespace):
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
log_with = "tensorboard" if args.log_with is None else args.log_with
if log_with in ["wandb", "all"]:
try:
import wandb
except ImportError:
raise ImportError("No wandb / wandb がインストールされていないようです")
log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
@@ -3197,6 +3210,20 @@ def sample_images(
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
@@ -3205,7 +3232,6 @@ def sample_images(
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
# endregion
# region 前処理用