mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Fix sample norms in batches
This commit is contained in:
@@ -680,12 +680,14 @@ def denoise(
|
||||
dim=tuple(range(1, len(noise_pred_cond.shape))),
|
||||
keepdim=True,
|
||||
)
|
||||
max_new_norm = cond_norm * float(renorm_cfg)
|
||||
noise_norm = torch.linalg.vector_norm(
|
||||
max_new_norms = cond_norm * float(renorm_cfg)
|
||||
noise_norms = torch.linalg.vector_norm(
|
||||
noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True
|
||||
)
|
||||
if noise_norm >= max_new_norm:
|
||||
noise_pred = noise_pred * (max_new_norm / noise_norm)
|
||||
# Iterate through batch
|
||||
for noise_norm, max_new_norm, noise in zip(noise_norms, max_new_norms, noise_pred):
|
||||
if noise_norm >= max_new_norm:
|
||||
noise = noise * (max_new_norm / noise_norm)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
|
||||
Reference in New Issue
Block a user