diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 7c5e6860..e559e718 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -327,14 +327,17 @@ def save_sd_model_on_epoch_end_or_stepwise( def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True): - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) + if support_text_encoder_caching: + parser.add_argument( + "--cache_text_encoder_outputs", + action="store_true", + help="cache text encoder outputs / text encoderの出力をキャッシュする", + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) parser.add_argument( "--disable_mmap_load_safetensors", action="store_true", @@ -342,7 +345,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en ) -def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): +def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.clip_skip is not None: @@ -365,7 +368,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # not hasattr(args, "weighted_captions") or not args.weighted_captions # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" - if supportTextEncoderCaching: + if support_text_encoder_caching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: args.cache_text_encoder_outputs = True logger.warning( diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index be538cdd..6dec31de 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -20,7 +20,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine self.is_sdxl = True def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): - sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) + # super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64 + sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False) train_dataset_group.verify_bucket_reso_steps(32) if val_dataset_group is not None: