Clean up code

This commit is contained in:
rockerBOO
2024-01-24 00:23:41 -05:00
parent d8155bfbe8
commit 38ef8ea8d6
2 changed files with 2 additions and 14 deletions

View File

@@ -70,19 +70,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
# min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
soft_min_snr_gamma_weight = 1 / (torch.pow(snr, 2) + (1 / float(gamma)))
with open("snr.txt", "a") as myfile:
myfile.write(f"{snr.item()},{gamma}\n")
# with open("snrmin.txt", "a") as myfile:
# myfile.write(f"{min_snr_gamma.item()},{soft_min_snr_gamma.item()}\n")
# print("soft_min_snr_gamma", soft_min_snr_gamma, 1 / (snr + (1 / float(gamma))))
# print("min_snr_gamma", min_snr_gamma)
# if v_prediction:
# snr_weight = torch.div(soft_min_snr_gamma, snr+1).float().to(loss.device)
# else:
# snr_weight = torch.div(soft_min_snr_gamma, snr).float().to(loss.device)
soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
loss = loss * soft_min_snr_gamma_weight
return loss

View File

@@ -4809,7 +4809,7 @@ def sample_images_common(
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)}, step=steps)
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass