diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d91a78ff..34b7f092 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -181,6 +181,11 @@ def train(args): for tmpl in templates: captions.append(tmpl.format(replace_to)) train_dataset_group.add_replacement("", captions) + + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None else: if args.num_vectors_per_token > 1: replace_to = " ".join(token_strings)