mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
feat: support wandb logging
This commit is contained in:
@@ -2067,7 +2067,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_with",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["tensorboard", "wandb", "all"],
|
||||
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
|
||||
)
|
||||
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
||||
parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前")
|
||||
parser.add_argument(
|
||||
"--noise_offset",
|
||||
type=float,
|
||||
@@ -2735,7 +2743,12 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
log_with = None
|
||||
logging_dir = None
|
||||
else:
|
||||
log_with = "tensorboard"
|
||||
log_with = "tensorboard" if args.log_with is None else args.log_with
|
||||
if log_with in ["wandb", "all"]:
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
||||
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
|
||||
@@ -3197,6 +3210,20 @@ def sample_images(
|
||||
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# wandb有効時のみログを送信
|
||||
try:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
try:
|
||||
import wandb
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
|
||||
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
@@ -3205,7 +3232,6 @@ def sample_images(
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region 前処理用
|
||||
|
||||
Reference in New Issue
Block a user