diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 11dd3feb..a95da382 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -680,12 +680,14 @@ def denoise( dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True, ) - max_new_norm = cond_norm * float(renorm_cfg) - noise_norm = torch.linalg.vector_norm( + max_new_norms = cond_norm * float(renorm_cfg) + noise_norms = torch.linalg.vector_norm( noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True ) - if noise_norm >= max_new_norm: - noise_pred = noise_pred * (max_new_norm / noise_norm) + # Iterate through batch + for noise_norm, max_new_norm, noise in zip(noise_norms, max_new_norms, noise_pred): + if noise_norm >= max_new_norm: + noise = noise * (max_new_norm / noise_norm) else: noise_pred = noise_pred_cond