diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index cb3f1bdd..c5e7ab39 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -283,8 +283,6 @@ def get_weighted_text_embeddings( prompt = [prompt] prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2) - # prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - # prompt_weights = [[1.0] * len(token) for token in prompt_tokens] # round up the longest length of tokens to a multiple of (model_max_length - 2) max_length = max([len(token) for token in prompt_tokens]) @@ -330,4 +328,4 @@ def get_weighted_text_embeddings( current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - return text_embeddings \ No newline at end of file + return text_embeddings