validation: Implement timestep-based validation processing

This commit is contained in:
Kohya S
2025-01-27 21:56:59 +09:00
parent 29f31d005f
commit 0750859133
2 changed files with 109 additions and 77 deletions

View File

@@ -446,6 +446,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
# TODO consider validation
# drop cached text encoder outputs
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None: