From 38ef8ea8d6af218ce92984ec1b42f8b0144b7daa Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 24 Jan 2024 00:23:41 -0500 Subject: [PATCH] Clean up code --- library/custom_train_functions.py | 14 +------------- library/train_util.py | 2 +- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index d9800345..6bd01a06 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -70,19 +70,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) - # min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) - soft_min_snr_gamma_weight = 1 / (torch.pow(snr, 2) + (1 / float(gamma))) - with open("snr.txt", "a") as myfile: - myfile.write(f"{snr.item()},{gamma}\n") - - # with open("snrmin.txt", "a") as myfile: - # myfile.write(f"{min_snr_gamma.item()},{soft_min_snr_gamma.item()}\n") - # print("soft_min_snr_gamma", soft_min_snr_gamma, 1 / (snr + (1 / float(gamma)))) - # print("min_snr_gamma", min_snr_gamma) - # if v_prediction: - # snr_weight = torch.div(soft_min_snr_gamma, snr+1).float().to(loss.device) - # else: - # snr_weight = torch.div(soft_min_snr_gamma, snr).float().to(loss.device) + soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma))) loss = loss * soft_min_snr_gamma_weight return loss diff --git a/library/train_util.py b/library/train_util.py index 6f62176d..4ac6728b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4809,7 +4809,7 @@ def sample_images_common( except ImportError: # 事前に一度確認するのでここはエラー出ないはず raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}, step=steps) + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) except: # wandb 無効時 pass