mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
support sample generation in TI training
This commit is contained in:
@@ -164,9 +164,13 @@ def train(args):
|
|||||||
for tmpl in templates:
|
for tmpl in templates:
|
||||||
captions.append(tmpl.format(replace_to))
|
captions.append(tmpl.format(replace_to))
|
||||||
train_dataset.add_replacement("", captions)
|
train_dataset.add_replacement("", captions)
|
||||||
elif args.num_vectors_per_token > 1:
|
else:
|
||||||
replace_to = " ".join(token_strings)
|
if args.num_vectors_per_token > 1:
|
||||||
train_dataset.add_replacement(args.token_string, replace_to)
|
replace_to = " ".join(token_strings)
|
||||||
|
train_dataset.add_replacement(args.token_string, replace_to)
|
||||||
|
prompt_replacement = (args.token_string, replace_to)
|
||||||
|
else:
|
||||||
|
prompt_replacement = None
|
||||||
|
|
||||||
train_dataset.make_buckets()
|
train_dataset.make_buckets()
|
||||||
|
|
||||||
@@ -288,7 +292,6 @@ def train(args):
|
|||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
with accelerator.accumulate(text_encoder):
|
with accelerator.accumulate(text_encoder):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -354,7 +357,8 @@ def train(args):
|
|||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
|
||||||
|
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
@@ -399,7 +403,8 @@ def train(args):
|
|||||||
if saving and args.save_state:
|
if saving and args.save_state:
|
||||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||||
|
|
||||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
|
||||||
|
vae, tokenizer, text_encoder, unet, prompt_replacement)
|
||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user