mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Update PO cached latents, move out functions, update calls
This commit is contained in:
@@ -336,24 +336,24 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
args: argparse.Namespace,
|
||||
accelerator: Accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
latents: torch.FloatTensor,
|
||||
batch: dict[str, torch.Tensor],
|
||||
text_encoder_conds,
|
||||
unet: flux_models.Flux,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
weight_dtype: torch.dtype,
|
||||
train_unet: bool,
|
||||
is_train=True,
|
||||
):
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
|
||||
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
return model_pred, noisy_model_input, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
    """Return ``loss`` unchanged.

    This implementation applies no post-processing: ``args``, ``timesteps``
    and ``noise_scheduler`` are accepted but never read.
    """
    return loss
|
||||
|
||||
Reference in New Issue
Block a user