mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Clean up code
This commit is contained in:
@@ -70,19 +70,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
|
||||
|
||||
def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
# min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
soft_min_snr_gamma_weight = 1 / (torch.pow(snr, 2) + (1 / float(gamma)))
|
||||
with open("snr.txt", "a") as myfile:
|
||||
myfile.write(f"{snr.item()},{gamma}\n")
|
||||
|
||||
# with open("snrmin.txt", "a") as myfile:
|
||||
# myfile.write(f"{min_snr_gamma.item()},{soft_min_snr_gamma.item()}\n")
|
||||
# print("soft_min_snr_gamma", soft_min_snr_gamma, 1 / (snr + (1 / float(gamma))))
|
||||
# print("min_snr_gamma", min_snr_gamma)
|
||||
# if v_prediction:
|
||||
# snr_weight = torch.div(soft_min_snr_gamma, snr+1).float().to(loss.device)
|
||||
# else:
|
||||
# snr_weight = torch.div(soft_min_snr_gamma, snr).float().to(loss.device)
|
||||
soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
|
||||
loss = loss * soft_min_snr_gamma_weight
|
||||
return loss
|
||||
|
||||
|
||||
@@ -4809,7 +4809,7 @@ def sample_images_common(
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)}, step=steps)
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user