Update PO cached latents, move out functions, update calls

This commit is contained in:
rockerBOO
2025-04-27 17:38:50 -04:00
parent 74529743d4
commit d22c827544
11 changed files with 480 additions and 129 deletions

View File

@@ -336,24 +336,24 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def get_noise_pred_and_target(
self,
args,
accelerator,
args: argparse.Namespace,
accelerator: Accelerator,
noise_scheduler,
latents,
batch,
latents: torch.FloatTensor,
batch: dict[str, torch.Tensor],
text_encoder_conds,
unet: flux_models.Flux,
unet,
network,
weight_dtype,
train_unet,
weight_dtype: torch.dtype,
train_unet: bool,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
    """Return ``loss`` unchanged.

    ``args``, ``timesteps`` and ``noise_scheduler`` are accepted but not
    read by this implementation; the value passed in as ``loss`` is handed
    back to the caller as-is, with no copy or modification.
    """
    return loss