mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
some minor fixes
This commit is contained in:
@@ -2075,7 +2075,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
|
||||
)
|
||||
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
||||
parser.add_argument("--log_tracker_name", type=str, default=None, help="name of tracker to use for logging / ログ出力に使用するtrackerの名前")
|
||||
parser.add_argument(
|
||||
"--log_tracker_name", type=str, default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_offset",
|
||||
type=float,
|
||||
@@ -2746,7 +2748,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
|
||||
if args.log_with is not None:
|
||||
log_with = "tensorboard" if args.log_with is None else args.log_with
|
||||
log_with = args.log_with
|
||||
if log_with in ["tensorboard", "all"]:
|
||||
if logging_dir is None:
|
||||
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
|
||||
@@ -2756,7 +2758,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
except ImportError:
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
if logging_dir is not None:
|
||||
os.makedirs(logging_dir)
|
||||
os.makedirs(logging_dir, exist_ok=True)
|
||||
os.environ["WANDB_DIR"] = logging_dir
|
||||
|
||||
accelerator = Accelerator(
|
||||
@@ -3222,14 +3224,12 @@ def sample_images(
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
try:
|
||||
import wandb
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
except: # wandb 無効時
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
|
||||
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
@@ -3239,6 +3239,7 @@ def sample_images(
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region 前処理用
|
||||
|
||||
Reference in New Issue
Block a user