def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
    """Scale ``loss`` by ``min(gamma / SNR(t), 1)`` for each timestep.

    The signal-to-noise ratio is derived from the scheduler's cumulative
    alpha products: ``SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t])``.

    Args:
        loss: Loss value(s) to re-weight. The weight is multiplied in directly,
            so ``loss`` must broadcast against one weight per entry of
            ``timesteps`` (presumably a per-sample loss of shape ``(batch,)``
            -- confirm against the caller).
        timesteps: Integer timestep indices, either a tensor or any iterable
            of ints.
        noise_scheduler: Object exposing an ``alphas_cumprod`` tensor indexed
            by timestep.
        gamma: Upper bound applied to the SNR (the ``--min_snr_gamma`` value).

    Returns:
        ``loss`` multiplied by the float32 SNR weight.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    # (sqrt(a) / sqrt(1 - a)) ** 2 simplifies to a / (1 - a); skipping the
    # square-root round trip avoids needless work and rounding error.
    all_snr = alphas_cumprod / (1.0 - alphas_cumprod)

    # Gather every requested timestep with one indexing operation. The index
    # tensor has to live on the same device as the table it indexes.
    if isinstance(timesteps, torch.Tensor):
        indices = timesteps.to(device=all_snr.device, dtype=torch.long)
    else:
        indices = torch.as_tensor(list(timesteps), dtype=torch.long, device=all_snr.device)
    snr = all_snr[indices]

    # min(gamma / snr, 1): where snr == 0 the ratio is +inf and clamps to 1.
    snr_weight = torch.clamp(gamma / snr, max=1.0).float()

    # The scheduler tables may live on a different device than the loss
    # (e.g. CPU tables with a GPU loss); move the weight so the product
    # does not fail with a device mismatch.
    if isinstance(loss, torch.Tensor):
        snr_weight = snr_weight.to(loss.device)
    return loss * snr_weight


def add_custom_train_arguments(parser: argparse.ArgumentParser):
    """Register the custom training options on ``parser``.

    Adds ``--min_snr_gamma`` (float, default ``None``) and
    ``--weighted_captions`` (boolean flag, default ``False``).
    """
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--weighted_captions",
        action="store_true",
        default=False,
        help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.",
    )