Fix sample norms in batches

This commit is contained in:
rockerBOO
2025-02-27 00:00:20 -05:00
parent 70403f6977
commit 542f980443

View File

@@ -680,12 +680,14 @@ def denoise(
dim=tuple(range(1, len(noise_pred_cond.shape))),
keepdim=True,
)
max_new_norm = cond_norm * float(renorm_cfg)
noise_norm = torch.linalg.vector_norm(
max_new_norms = cond_norm * float(renorm_cfg)
noise_norms = torch.linalg.vector_norm(
noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True
)
if noise_norm >= max_new_norm:
noise_pred = noise_pred * (max_new_norm / noise_norm)
# Iterate through batch
for noise_norm, max_new_norm, noise in zip(noise_norms, max_new_norms, noise_pred):
if noise_norm >= max_new_norm:
noise = noise * (max_new_norm / noise_norm)
else:
noise_pred = noise_pred_cond