format by black

This commit is contained in:
Kohya S
2024-02-18 09:13:24 +09:00
parent 75e4a951d0
commit d1fb480887

View File

@@ -70,8 +70,10 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
@@ -1483,7 +1485,9 @@ class DreamBoothDataset(BaseDataset):
img_paths, captions = load_dreambooth_dir(subset)
if len(img_paths) < 1:
logger.warning(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
logger.warning(
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
)
continue
if subset.is_reg:
@@ -1574,7 +1578,9 @@ class FineTuningDataset(BaseDataset):
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
if len(metadata) < 1:
logger.warning(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します")
logger.warning(
f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します"
)
continue
tags_list = []
@@ -1655,7 +1661,9 @@ class FineTuningDataset(BaseDataset):
logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
elif not npz_all:
use_npz_latents = False
logger.warning(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
logger.warning(
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
)
if flip_aug_in_subset:
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# else:
@@ -1675,7 +1683,9 @@ class FineTuningDataset(BaseDataset):
if sizes is None:
if use_npz_latents:
use_npz_latents = False
logger.warning(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
logger.warning(
f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します"
)
assert (
resolution is not None
@@ -1871,7 +1881,9 @@ class ControlNetDataset(BaseDataset):
assert (
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
cond_img = cv2.resize(
cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA
) # INTER_AREAでやりたいのでcv2でリサイズ
# TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ
@@ -2025,7 +2037,9 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli
def debug_dataset(train_dataset, show_input_ids=False):
logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
logger.info("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
logger.info(
"`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します"
)
epoch = 1
while True:
@@ -2686,7 +2700,9 @@ def get_sai_model_spec(
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
parser.add_argument(
"--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む"
)
parser.add_argument(
"--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする"
)
@@ -2726,7 +2742,10 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument(
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない"
"--max_grad_norm",
default=1.0,
type=float,
help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない",
)
parser.add_argument(
@@ -2773,13 +2792,23 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
parser.add_argument(
"--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ"
"--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ"
)
parser.add_argument(
"--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
"--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名"
)
parser.add_argument(
"--huggingface_repo_id",
type=str,
default=None,
help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名",
)
parser.add_argument(
"--huggingface_repo_type",
type=str,
default=None,
help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類",
)
parser.add_argument(
"--huggingface_path_in_repo",
@@ -2815,10 +2844,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="precision in saving / 保存時に精度を変更して保存する",
)
parser.add_argument(
"--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する"
"--save_every_n_epochs",
type=int,
default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する",
)
parser.add_argument(
"--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する"
"--save_every_n_steps",
type=int,
default=None,
help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
)
parser.add_argument(
"--save_n_epoch_ratio",
@@ -2870,7 +2905,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
)
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
parser.add_argument(
"--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う"
)
parser.add_argument(
"--dynamo_backend",
type=str,
@@ -2888,7 +2925,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使うPyTorch 2.0が必要)",
)
parser.add_argument(
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
"--vae",
type=str,
default=None,
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ",
)
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
@@ -2920,7 +2960,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数",
)
parser.add_argument(
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
parser.add_argument(
@@ -2962,7 +3006,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
choices=["tensorboard", "wandb", "all"],
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
)
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument(
"--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列"
)
parser.add_argument(
"--log_tracker_name",
type=str,
@@ -3050,14 +3096,19 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--highvram",
action="store_true",
help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) " +
"/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等VRAMが多い環境向け",
help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) "
+ "/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等VRAMが多い環境向け",
)
parser.add_argument(
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
"--sample_every_n_steps",
type=int,
default=None,
help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する",
)
parser.add_argument(
"--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
)
parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する")
parser.add_argument(
"--sample_every_n_epochs",
type=int,
@@ -3065,7 +3116,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)",
)
parser.add_argument(
"--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル"
"--sample_prompts",
type=str,
default=None,
help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル",
)
parser.add_argument(
"--sample_sampler",
@@ -3152,7 +3206,9 @@ def verify_training_args(args: argparse.Namespace):
HIGH_VRAM = True
if args.v_parameterization and not args.v2:
logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
logger.warning(
"v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません"
)
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -3199,8 +3255,12 @@ def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
):
# dataset common
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする")
parser.add_argument(
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
)
parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字")
parser.add_argument(
"--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"
@@ -3236,8 +3296,12 @@ def add_dataset_arguments(
default=None,
help="suffix for caption text / captionのテキストの末尾に付ける文字列",
)
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
parser.add_argument(
"--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする"
)
parser.add_argument(
"--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする"
)
parser.add_argument(
"--face_crop_aug_range",
type=str,
@@ -3250,7 +3314,9 @@ def add_dataset_arguments(
help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする顔を中心としたaugmentationを行うときに画風の学習用に指定する",
)
parser.add_argument(
"--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)"
"--debug_dataset",
action="store_true",
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)",
)
parser.add_argument(
"--resolution",
@@ -3263,14 +3329,18 @@ def add_dataset_arguments(
action="store_true",
help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheするaugmentationは使用不可 ",
)
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
parser.add_argument(
"--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ"
)
parser.add_argument(
"--cache_latents_to_disk",
action="store_true",
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheするaugmentationは使用不可",
)
parser.add_argument(
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
"--enable_bucket",
action="store_true",
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする",
)
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
@@ -3281,7 +3351,9 @@ def add_dataset_arguments(
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
"--bucket_no_upscale",
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
@@ -3325,13 +3397,20 @@ def add_dataset_arguments(
if support_dreambooth:
# DreamBooth dataset
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
parser.add_argument(
"--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ"
)
if support_caption:
# caption dataset
parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
parser.add_argument(
"--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数"
"--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル"
)
parser.add_argument(
"--dataset_repeats",
type=int,
default=1,
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数",
)
@@ -3469,7 +3548,9 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
loop = asyncio.get_event_loop()
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
if len(results) == 0:
raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
raise ValueError(
"No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした"
)
dirname = os.path.dirname(results[0])
accelerator.load_state(dirname)
@@ -3610,7 +3691,9 @@ def get_optimizer(args, trainable_params):
elif optimizer_type == "SGDNesterov".lower():
logger.info(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
if "momentum" not in optimizer_kwargs:
logger.info(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
logger.info(
f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します"
)
optimizer_kwargs["momentum"] = 0.9
optimizer_class = torch.optim.SGD
@@ -3689,7 +3772,9 @@ def get_optimizer(args, trainable_params):
if "relative_step" not in optimizer_kwargs:
optimizer_kwargs["relative_step"] = True # default
if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
logger.info(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
logger.info(
f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします"
)
optimizer_kwargs["relative_step"] = True
logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
@@ -3913,7 +3998,9 @@ def prepare_accelerator(args: argparse.Namespace):
log_with = args.log_with
if log_with in ["tensorboard", "all"]:
if logging_dir is None:
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
raise ValueError(
"logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください"
)
if log_with in ["wandb", "all"]:
try:
import wandb
@@ -3932,9 +4019,13 @@ def prepare_accelerator(args: argparse.Namespace):
kwargs_handlers = (
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None,
(
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None
),
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
accelerator = Accelerator(
@@ -4083,7 +4174,9 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(
encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]
) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
@@ -4666,6 +4759,7 @@ def line_to_prompt_dict(line: str) -> dict:
return prompt_dict
def sample_images_common(
pipe_class,
accelerator: Accelerator,
@@ -4704,10 +4798,10 @@ def sample_images_common(
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
org_vae_device = vae.device # CPUにいるはず
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
@@ -4774,18 +4868,22 @@ def sample_images_common(
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
for prompt_dict in prompts:
sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i::distributed_state.num_processes])
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)
# clear pipeline and cache to reduce vram usage
del pipeline
@@ -4800,7 +4898,18 @@ def sample_images_common(
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=None):
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
pipeline,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
controlnet=None,
):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
@@ -4870,9 +4979,7 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
)
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
@@ -4886,9 +4993,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# endregion
# endregion
# region 前処理用