Update PO cached latents, move out functions, update calls

This commit is contained in:
rockerBOO
2025-04-27 17:38:50 -04:00
parent 74529743d4
commit d22c827544
11 changed files with 480 additions and 129 deletions

View File

@@ -323,7 +323,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
weight_dtype,
train_unet,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -389,7 +389,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss