add tag dropout

This commit is contained in:
Kohya S
2023-02-09 21:35:27 +09:00
parent f7b5abb595
commit 3a72e6f003
5 changed files with 60 additions and 46 deletions

View File

@@ -235,7 +235,7 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler)
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
print(len(index_no_updates), torch.sum(index_no_updates))
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
@@ -296,6 +296,7 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
text_encoder.train()
@@ -383,8 +384,8 @@ def train(args):
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
d = updated_embs - bef_epo_embs
print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
# d = updated_embs - bef_epo_embs
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name