From 6a5f87d874031b66152de2356a96076061e596de Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 8 Apr 2023 21:45:57 +0900 Subject: [PATCH] disable weighted captions in TI/XTI training --- library/custom_train_functions.py | 15 ++++++++------- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 9c0c4028..7eb829fa 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -18,19 +18,20 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): return loss -def add_custom_train_arguments(parser: argparse.ArgumentParser): +def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True): parser.add_argument( "--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", ) - parser.add_argument( - "--weighted_captions", - action="store_true", - default=False, - help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.", - ) + if support_weighted_captions: + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", + ) re_attention = re.compile( diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d8d803a4..98639345 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -549,7 +549,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser, False) parser.add_argument( "--save_model_as", diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 9bd775ef..db46ad1b 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -603,7 +603,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) - custom_train_functions.add_custom_train_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser, False) parser.add_argument( "--save_model_as",