mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Fix sample norms in batches
This commit is contained in:
@@ -680,12 +680,14 @@ def denoise(
|
|||||||
dim=tuple(range(1, len(noise_pred_cond.shape))),
|
dim=tuple(range(1, len(noise_pred_cond.shape))),
|
||||||
keepdim=True,
|
keepdim=True,
|
||||||
)
|
)
|
||||||
max_new_norm = cond_norm * float(renorm_cfg)
|
max_new_norms = cond_norm * float(renorm_cfg)
|
||||||
noise_norm = torch.linalg.vector_norm(
|
noise_norms = torch.linalg.vector_norm(
|
||||||
noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True
|
noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True
|
||||||
)
|
)
|
||||||
if noise_norm >= max_new_norm:
|
# Iterate through batch
|
||||||
noise_pred = noise_pred * (max_new_norm / noise_norm)
|
for noise_norm, max_new_norm, noise in zip(noise_norms, max_new_norms, noise_pred):
|
||||||
|
if noise_norm >= max_new_norm:
|
||||||
|
noise = noise * (max_new_norm / noise_norm)
|
||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_cond
|
noise_pred = noise_pred_cond
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user