some minor fixes

This commit is contained in:
Kohya S
2023-04-22 09:55:04 +09:00
parent c430cf481a
commit 220436244c

View File

@@ -2075,7 +2075,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
) )
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前") parser.add_argument(
"--log_tracker_name", type=str, default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名"
)
parser.add_argument( parser.add_argument(
"--noise_offset", "--noise_offset",
type=float, type=float,
@@ -2746,7 +2748,7 @@ def prepare_accelerator(args: argparse.Namespace):
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
if args.log_with is not None: if args.log_with is not None:
log_with = "tensorboard" if args.log_with is None else args.log_with log_with = args.log_with
if log_with in ["tensorboard", "all"]: if log_with in ["tensorboard", "all"]:
if logging_dir is None: if logging_dir is None:
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください") raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
@@ -2756,7 +2758,7 @@ def prepare_accelerator(args: argparse.Namespace):
except ImportError: except ImportError:
raise ImportError("No wandb / wandb がインストールされていないようです") raise ImportError("No wandb / wandb がインストールされていないようです")
if logging_dir is not None: if logging_dir is not None:
os.makedirs(logging_dir) os.makedirs(logging_dir, exist_ok=True)
os.environ["WANDB_DIR"] = logging_dir os.environ["WANDB_DIR"] = logging_dir
accelerator = Accelerator( accelerator = Accelerator(
@@ -3222,14 +3224,12 @@ def sample_images(
wandb_tracker = accelerator.get_tracker("wandb") wandb_tracker = accelerator.get_tracker("wandb")
try: try:
import wandb import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです") raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時 except: # wandb 無効時
pass pass
# clear pipeline and cache to reduce vram usage # clear pipeline and cache to reduce vram usage
del pipeline del pipeline
@@ -3239,6 +3239,7 @@ def sample_images(
torch.cuda.set_rng_state(cuda_rng_state) torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device) vae.to(org_vae_device)
# endregion # endregion
# region 前処理用 # region 前処理用