mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
validation: Implement timestep-based validation processing
This commit is contained in:
185
train_network.py
185
train_network.py
@@ -9,6 +9,7 @@ import random
|
||||
import time
|
||||
import json
|
||||
from multiprocessing import Value
|
||||
import numpy as np
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -1248,10 +1249,6 @@ class NetworkTrainer:
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
validation_steps = (
|
||||
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||
)
|
||||
|
||||
# training loop
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||
@@ -1270,6 +1267,17 @@ class NetworkTrainer:
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
validation_steps = (
|
||||
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||
)
|
||||
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
|
||||
validation_total_steps = validation_steps * len(validation_timesteps)
|
||||
original_args_min_timestep = args.min_timestep
|
||||
original_args_max_timestep = args.max_timestep
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||
current_epoch.value = epoch + 1
|
||||
@@ -1385,12 +1393,96 @@ class NetworkTrainer:
|
||||
accelerator.unwrap_model(network).eval()
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
|
||||
range(validation_total_steps),
|
||||
smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="validation steps",
|
||||
)
|
||||
val_ts_step = 0
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
for timestep in validation_timesteps:
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix(
|
||||
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
|
||||
)
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/step_current": current_loss,
|
||||
"val_step": (epoch * validation_total_steps) + val_ts_step,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
val_ts_step += 1
|
||||
|
||||
if is_tracking:
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||
"loss/validation/step_divergence": loss_validation_divergence,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
args.min_timestep = original_args_min_timestep
|
||||
args.max_timestep = original_args_max_timestep
|
||||
optimizer_train_fn()
|
||||
accelerator.unwrap_model(network).train()
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# EPOCH VALIDATION
|
||||
should_validate_epoch = (
|
||||
(epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
|
||||
)
|
||||
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
optimizer_eval_fn()
|
||||
accelerator.unwrap_model(network).eval()
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_total_steps),
|
||||
smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="epoch validation steps",
|
||||
)
|
||||
|
||||
val_ts_step = 0
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
for timestep in validation_timesteps:
|
||||
args.min_timestep = args.max_timestep = timestep
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
@@ -1408,89 +1500,26 @@ class NetworkTrainer:
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average})
|
||||
val_progress_bar.set_postfix(
|
||||
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
|
||||
)
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/step_current": current_loss,
|
||||
"val_step": (epoch * validation_steps) + val_step,
|
||||
"loss/validation/epoch_current": current_loss,
|
||||
"epoch": epoch + 1,
|
||||
"val_step": (epoch * validation_total_steps) + val_ts_step,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if is_tracking:
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||
"loss/validation/step_divergence": loss_validation_divergence,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
optimizer_train_fn()
|
||||
accelerator.unwrap_model(network).train()
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# EPOCH VALIDATION
|
||||
should_validate_epoch = (
|
||||
(epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
|
||||
)
|
||||
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
optimizer_eval_fn()
|
||||
accelerator.unwrap_model(network).eval()
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps),
|
||||
smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="epoch validation steps",
|
||||
)
|
||||
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average})
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/epoch_current": current_loss,
|
||||
"epoch": epoch + 1,
|
||||
"val_step": (epoch * validation_steps) + val_step,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
val_ts_step += 1
|
||||
|
||||
if is_tracking:
|
||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||
@@ -1502,6 +1531,8 @@ class NetworkTrainer:
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
args.min_timestep = original_args_min_timestep
|
||||
args.max_timestep = original_args_max_timestep
|
||||
optimizer_train_fn()
|
||||
accelerator.unwrap_model(network).train()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user