fix to work cache latents/text encoder outputs

This commit is contained in:
kohya-ss
2024-10-13 19:08:16 +09:00
parent 2244cf5b83
commit bfc3a65acd
3 changed files with 26 additions and 14 deletions

View File

@@ -27,7 +27,7 @@ from library.config_util import (
BlueprintGenerator,
)
from library.utils import setup_logging, add_logging_arguments
from tools import cache_latents
from cache_latents import set_tokenize_strategy
setup_logging()
import logging
@@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
train_util.enable_high_vram(args)
args.cache_text_encoder_outputs = True
args.cache_text_encoder_outputs_to_disk = True
@@ -57,8 +58,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
assert (
is_sdxl or args.weighted_captions is None
), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です"
cache_latents.set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
# データセットを準備する
use_user_config = args.dataset_config is not None
@@ -178,7 +179,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
accelerator.print(f"Finished caching text encoder outputs to disk.")
def setup_parser() -> argparse.ArgumentParser:
@@ -188,9 +189,10 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)
config_util.add_config_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
parser.add_argument(
@@ -205,6 +207,12 @@ def setup_parser() -> argparse.ArgumentParser:
help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check."
" / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。",
)
parser.add_argument(
"--weighted_captions",
action="store_true",
default=False,
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
)
return parser