diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index df718133..85515157 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -164,9 +164,13 @@ def train(args):
         for tmpl in templates:
             captions.append(tmpl.format(replace_to))
         train_dataset.add_replacement("", captions)
-    elif args.num_vectors_per_token > 1:
-        replace_to = " ".join(token_strings)
-        train_dataset.add_replacement(args.token_string, replace_to)
+    else:
+        if args.num_vectors_per_token > 1:
+            replace_to = " ".join(token_strings)
+            train_dataset.add_replacement(args.token_string, replace_to)
+            prompt_replacement = (args.token_string, replace_to)
+        else:
+            prompt_replacement = None
 
     train_dataset.make_buckets()
 
@@ -288,7 +292,6 @@ def train(args):
         text_encoder.train()
 
         loss_total = 0
-        bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
@@ -354,7 +357,8 @@ def train(args):
                 progress_bar.update(1)
                 global_step += 1
 
-                train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+                train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
+                                         vae, tokenizer, text_encoder, unet, prompt_replacement)
 
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
@@ -399,7 +403,8 @@ def train(args):
 
         if saving and args.save_state:
             train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
-        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
+                                 vae, tokenizer, text_encoder, unet, prompt_replacement)
 
         # end of epoch
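
A caveat in the first hunk: `prompt_replacement` is only assigned in the new `else:` branch, so when the caption-template path (`use_template`) is taken, the later `sample_images(..., prompt_replacement)` calls would raise a `NameError`. A minimal sketch of the missing assignment, reusing the patch's own variable names (an editor's suggestion, not part of this commit):

```python
# Sketch: inside the `if use_template:` branch, after
# train_dataset.add_replacement("", captions) -- mirror the else branch
# so prompt_replacement is defined on every code path.
if args.num_vectors_per_token > 1:
    prompt_replacement = (args.token_string, replace_to)
else:
    prompt_replacement = None
```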
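The consuming side, `train_util.sample_images`, is not part of this diff; presumably it applies the `(search, replace)` pair to each sample prompt so the single user-facing token expands to all trained token strings before encoding. A hedged sketch of that substitution, with `apply_prompt_replacement` as a hypothetical helper name:

```python
# Hypothetical helper illustrating how sample_images could consume the new
# prompt_replacement argument (train_util.py is not shown in this diff).
def apply_prompt_replacement(prompt: str, prompt_replacement):
    if prompt_replacement is not None:
        search, replace = prompt_replacement
        prompt = prompt.replace(search, replace)
    return prompt

# e.g. with token_string "mychar" and num_vectors_per_token 3, assuming
# token_strings == ["mychar", "mychar1", "mychar2"]:
# apply_prompt_replacement("a photo of mychar", ("mychar", "mychar mychar1 mychar2"))
# -> "a photo of mychar mychar1 mychar2"
```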