Clean up custom_train_functions.py

Removed commented out lines from earlier bugfix.
This commit is contained in:
AI-Casanova
2023-04-08 00:44:56 -05:00
committed by GitHub
parent 0d54609435
commit dbab72153f

View File

@@ -283,8 +283,6 @@ def get_weighted_text_embeddings(
prompt = [prompt]
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
# prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
# prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
@@ -330,4 +328,4 @@ def get_weighted_text_embeddings(
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
return text_embeddings
return text_embeddings