feat: support wandb logging

This commit is contained in:
Plat
2023-04-20 01:41:12 +09:00
parent 334589af4e
commit 27ffd9fe3d
6 changed files with 33 additions and 7 deletions

View File

@@ -260,7 +260,7 @@ def train(args):
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("finetuning") accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")

View File

@@ -2067,7 +2067,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する", help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
) )
parser.add_argument(
"--log_with",
type=str,
default=None,
choices=["tensorboard", "wandb", "all"],
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
)
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前")
parser.add_argument( parser.add_argument(
"--noise_offset", "--noise_offset",
type=float, type=float,
@@ -2735,7 +2743,12 @@ def prepare_accelerator(args: argparse.Namespace):
log_with = None log_with = None
logging_dir = None logging_dir = None
else: else:
log_with = "tensorboard" log_with = "tensorboard" if args.log_with is None else args.log_with
if log_with in ["wandb", "all"]:
try:
import wandb
except ImportError:
raise ImportError("No wandb / wandb がインストールされていないようです")
log_prefix = "" if args.log_prefix is None else args.log_prefix log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
@@ -3197,6 +3210,20 @@ def sample_images(
image.save(os.path.join(save_dir, img_filename)) image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# clear pipeline and cache to reduce vram usage # clear pipeline and cache to reduce vram usage
del pipeline del pipeline
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -3205,7 +3232,6 @@ def sample_images(
torch.cuda.set_rng_state(cuda_rng_state) torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device) vae.to(org_vae_device)
# endregion # endregion
# region 前処理用 # region 前処理用

View File

@@ -231,7 +231,7 @@ def train(args):
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("dreambooth") accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
loss_list = [] loss_list = []
loss_total = 0.0 loss_total = 0.0

View File

@@ -538,7 +538,7 @@ def train(args):
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("network_train") accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
loss_list = [] loss_list = []
loss_total = 0.0 loss_total = 0.0

View File

@@ -337,7 +337,7 @@ def train(args):
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion") accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")

View File

@@ -371,7 +371,7 @@ def train(args):
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion") accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")