mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
feat: support wandb logging
This commit is contained in:
@@ -260,7 +260,7 @@ def train(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
accelerator.init_trackers("finetuning")
|
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
|
||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
|||||||
@@ -2067,7 +2067,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
default=None,
|
default=None,
|
||||||
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
|
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log_with",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
choices=["tensorboard", "wandb", "all"],
|
||||||
|
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
|
||||||
|
)
|
||||||
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
||||||
|
parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--noise_offset",
|
"--noise_offset",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -2735,7 +2743,12 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
log_with = None
|
log_with = None
|
||||||
logging_dir = None
|
logging_dir = None
|
||||||
else:
|
else:
|
||||||
log_with = "tensorboard"
|
log_with = "tensorboard" if args.log_with is None else args.log_with
|
||||||
|
if log_with in ["wandb", "all"]:
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||||
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
||||||
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
|
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||||
|
|
||||||
@@ -3197,6 +3210,20 @@ def sample_images(
|
|||||||
|
|
||||||
image.save(os.path.join(save_dir, img_filename))
|
image.save(os.path.join(save_dir, img_filename))
|
||||||
|
|
||||||
|
# wandb有効時のみログを送信
|
||||||
|
try:
|
||||||
|
wandb_tracker = accelerator.get_tracker("wandb")
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||||
|
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||||
|
|
||||||
|
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||||
|
except: # wandb 無効時
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# clear pipeline and cache to reduce vram usage
|
# clear pipeline and cache to reduce vram usage
|
||||||
del pipeline
|
del pipeline
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -3205,7 +3232,6 @@ def sample_images(
|
|||||||
torch.cuda.set_rng_state(cuda_rng_state)
|
torch.cuda.set_rng_state(cuda_rng_state)
|
||||||
vae.to(org_vae_device)
|
vae.to(org_vae_device)
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region 前処理用
|
# region 前処理用
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ def train(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
accelerator.init_trackers("dreambooth")
|
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
|
||||||
|
|
||||||
loss_list = []
|
loss_list = []
|
||||||
loss_total = 0.0
|
loss_total = 0.0
|
||||||
|
|||||||
@@ -538,7 +538,7 @@ def train(args):
|
|||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||||
)
|
)
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
accelerator.init_trackers("network_train")
|
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
|
||||||
|
|
||||||
loss_list = []
|
loss_list = []
|
||||||
loss_total = 0.0
|
loss_total = 0.0
|
||||||
|
|||||||
@@ -337,7 +337,7 @@ def train(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
accelerator.init_trackers("textual_inversion")
|
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
|
||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
|||||||
@@ -371,7 +371,7 @@ def train(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
accelerator.init_trackers("textual_inversion")
|
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
|
||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
|||||||
Reference in New Issue
Block a user