validation: Implement timestep-based validation processing

This commit is contained in:
Kohya S
2025-01-27 21:56:59 +09:00
parent 29f31d005f
commit 0750859133
2 changed files with 109 additions and 77 deletions

View File

@@ -446,6 +446,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
         prepare_fp8(text_encoder, weight_dtype)

     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        # TODO consider validation
         # drop cached text encoder outputs
         text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
         if text_encoder_outputs_list is not None:

View File

@@ -9,6 +9,7 @@ import random
 import time
 import json
 from multiprocessing import Value
+import numpy as np

 import toml
 from tqdm import tqdm
@@ -1248,10 +1249,6 @@ class NetworkTrainer:
     # log empty object to commit the sample images to wandb
     accelerator.log({}, step=0)

-    validation_steps = (
-        min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
-    )
-
     # training loop
     if initial_step > 0:  # only if skip_until_initial_step is specified
         for skip_epoch in range(epoch_to_start):  # skip epochs
@@ -1270,6 +1267,17 @@ class NetworkTrainer:
     clean_memory_on_device(accelerator.device)

+    validation_steps = (
+        min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
+    )
+    NUM_VALIDATION_TIMESTEPS = 4  # 200, 400, 600, 800 TODO make this configurable
+    min_timestep = 0 if args.min_timestep is None else args.min_timestep
+    max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
+    validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
+    validation_total_steps = validation_steps * len(validation_timesteps)
+    original_args_min_timestep = args.min_timestep
+    original_args_max_timestep = args.max_timestep
+
     for epoch in range(epoch_to_start, num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
         current_epoch.value = epoch + 1
@@ -1385,15 +1393,22 @@ class NetworkTrainer:
     accelerator.unwrap_model(network).eval()
     val_progress_bar = tqdm(
-        range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
+        range(validation_total_steps),
+        smoothing=0,
+        disable=not accelerator.is_local_main_process,
+        desc="validation steps",
     )
+    val_ts_step = 0

     for val_step, batch in enumerate(val_dataloader):
         if val_step >= validation_steps:
             break

+        for timestep in validation_timesteps:
             # temporary, for batch processing
             self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+            args.min_timestep = args.max_timestep = timestep  # dirty hack to change timestep

             loss = self.process_batch(
                 batch,
                 text_encoders,
@@ -1413,17 +1428,21 @@
             )

             current_loss = loss.detach().item()
-            val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
+            val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
             val_progress_bar.update(1)
-            val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average})
+            val_progress_bar.set_postfix(
+                {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
+            )

             if is_tracking:
                 logs = {
                     "loss/validation/step_current": current_loss,
-                    "val_step": (epoch * validation_steps) + val_step,
+                    "val_step": (epoch * validation_total_steps) + val_ts_step,
                 }
                 accelerator.log(logs, step=global_step)
+            val_ts_step += 1

     if is_tracking:
         loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
         logs = {
@@ -1432,6 +1451,8 @@
         }
         accelerator.log(logs, step=global_step)

+    args.min_timestep = original_args_min_timestep
+    args.max_timestep = original_args_max_timestep

     optimizer_train_fn()
     accelerator.unwrap_model(network).train()
@@ -1448,16 +1469,20 @@
     accelerator.unwrap_model(network).eval()
     val_progress_bar = tqdm(
-        range(validation_steps),
+        range(validation_total_steps),
         smoothing=0,
         disable=not accelerator.is_local_main_process,
         desc="epoch validation steps",
     )
+    val_ts_step = 0

     for val_step, batch in enumerate(val_dataloader):
         if val_step >= validation_steps:
             break

+        for timestep in validation_timesteps:
+            args.min_timestep = args.max_timestep = timestep
             # temporary, for batch processing
             self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
@@ -1480,18 +1505,22 @@
             )

             current_loss = loss.detach().item()
-            val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
+            val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
             val_progress_bar.update(1)
-            val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average})
+            val_progress_bar.set_postfix(
+                {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
+            )

             if is_tracking:
                 logs = {
                     "loss/validation/epoch_current": current_loss,
                     "epoch": epoch + 1,
-                    "val_step": (epoch * validation_steps) + val_step,
+                    "val_step": (epoch * validation_total_steps) + val_ts_step,
                 }
                 accelerator.log(logs, step=global_step)
+            val_ts_step += 1

     if is_tracking:
         avr_loss: float = val_epoch_loss_recorder.moving_average
         loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss
@@ -1502,6 +1531,8 @@
         }
         accelerator.log(logs, step=global_step)

+    args.min_timestep = original_args_min_timestep
+    args.max_timestep = original_args_max_timestep

     optimizer_train_fn()
     accelerator.unwrap_model(network).train()