Update PO cached latents, move out functions, update calls

This commit is contained in:
rockerBOO
2025-04-27 17:38:50 -04:00
parent 74529743d4
commit d22c827544
11 changed files with 480 additions and 129 deletions

View File

@@ -895,7 +895,7 @@ def compute_density_for_timestep_sampling(
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.