Merge pull request #428 from p1atdev/dev

Add WandB logging support
This commit is contained in:
Kohya S
2023-04-22 09:39:01 +09:00
committed by GitHub
7 changed files with 43 additions and 9 deletions

3
.gitignore vendored
View File

@@ -4,4 +4,5 @@ wd14_tagger_model
venv venv
*.egg-info *.egg-info
build build
.vscode .vscode
wandb

View File

@@ -260,7 +260,7 @@ def train(args):
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("finetuning") accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")

View File

@@ -2067,7 +2067,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する", help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
) )
parser.add_argument(
"--log_with",
type=str,
default=None,
choices=["tensorboard", "wandb", "all"],
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
)
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前")
parser.add_argument( parser.add_argument(
"--noise_offset", "--noise_offset",
type=float, type=float,
@@ -2732,13 +2740,25 @@ def load_tokenizer(args: argparse.Namespace):
def prepare_accelerator(args: argparse.Namespace): def prepare_accelerator(args: argparse.Namespace):
if args.logging_dir is None: if args.logging_dir is None:
log_with = None
logging_dir = None logging_dir = None
else: else:
log_with = "tensorboard"
log_prefix = "" if args.log_prefix is None else args.log_prefix log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
if args.log_with is not None:
log_with = "tensorboard" if args.log_with is None else args.log_with
if log_with in ["tensorboard", "all"]:
if logging_dir is None:
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
if log_with in ["wandb", "all"]:
try:
import wandb
except ImportError:
raise ImportError("No wandb / wandb がインストールされていないようです")
if logging_dir is not None:
os.makedirs(logging_dir)
os.environ["WANDB_DIR"] = logging_dir
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision, mixed_precision=args.mixed_precision,
@@ -3197,6 +3217,20 @@ def sample_images(
image.save(os.path.join(save_dir, img_filename)) image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# clear pipeline and cache to reduce vram usage # clear pipeline and cache to reduce vram usage
del pipeline del pipeline
torch.cuda.empty_cache() torch.cuda.empty_cache()
@@ -3205,7 +3239,6 @@ def sample_images(
torch.cuda.set_rng_state(cuda_rng_state) torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device) vae.to(org_vae_device)
# endregion # endregion
# region 前処理用 # region 前処理用

View File

@@ -231,7 +231,7 @@ def train(args):
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("dreambooth") accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
loss_list = [] loss_list = []
loss_total = 0.0 loss_total = 0.0

View File

@@ -538,7 +538,7 @@ def train(args):
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("network_train") accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
loss_list = [] loss_list = []
loss_total = 0.0 loss_total = 0.0

View File

@@ -337,7 +337,7 @@ def train(args):
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion") accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")

View File

@@ -371,7 +371,7 @@ def train(args):
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion") accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")