support sample generation in TI training

This commit is contained in:
Kohya S
2023-02-28 22:05:10 +09:00
parent dd523c94ff
commit 57c565c402

View File

@@ -164,9 +164,13 @@ def train(args):
         for tmpl in templates:
             captions.append(tmpl.format(replace_to))
         train_dataset.add_replacement("", captions)
-    elif args.num_vectors_per_token > 1:
-        replace_to = " ".join(token_strings)
-        train_dataset.add_replacement(args.token_string, replace_to)
+    else:
+        if args.num_vectors_per_token > 1:
+            replace_to = " ".join(token_strings)
+            train_dataset.add_replacement(args.token_string, replace_to)
+            prompt_replacement = (args.token_string, replace_to)
+        else:
+            prompt_replacement = None
 
     train_dataset.make_buckets()
@@ -288,7 +292,6 @@ def train(args):
         text_encoder.train()
         loss_total = 0
-        bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
@@ -354,7 +357,8 @@ def train(args):
                     progress_bar.update(1)
                     global_step += 1
 
-                    train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+                    train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
+                                             vae, tokenizer, text_encoder, unet, prompt_replacement)
 
                 current_loss = loss.detach().item()
                 if args.logging_dir is not None:
@@ -399,7 +403,8 @@ def train(args):
             if saving and args.save_state:
                 train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
-        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
+                                 vae, tokenizer, text_encoder, unet, prompt_replacement)
 
         # end of epoch