diff --git a/sd3_train_network.py b/sd3_train_network.py
index 2f457949..d4f13125 100644
--- a/sd3_train_network.py
+++ b/sd3_train_network.py
@@ -446,6 +446,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
             prepare_fp8(text_encoder, weight_dtype)
 
     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        # TODO consider validation
         # drop cached text encoder outputs
         text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
         if text_encoder_outputs_list is not None:
diff --git a/train_network.py b/train_network.py
index 9b8036f8..a63e9d1e 100644
--- a/train_network.py
+++ b/train_network.py
@@ -9,6 +9,7 @@ import random
 import time
 import json
 from multiprocessing import Value
+import numpy as np
 
 import toml
 from tqdm import tqdm
@@ -1248,10 +1249,6 @@ class NetworkTrainer:
             # log empty object to commit the sample images to wandb
             accelerator.log({}, step=0)
 
-        validation_steps = (
-            min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
-        )
-
         # training loop
         if initial_step > 0:  # only if skip_until_initial_step is specified
             for skip_epoch in range(epoch_to_start):  # skip epochs
@@ -1270,6 +1267,17 @@ class NetworkTrainer:
 
         clean_memory_on_device(accelerator.device)
 
+        validation_steps = (
+            min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
+        )
+        NUM_VALIDATION_TIMESTEPS = 4  # 200, 400, 600, 800 TODO make this configurable
+        min_timestep = 0 if args.min_timestep is None else args.min_timestep
+        max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
+        validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
+        validation_total_steps = validation_steps * len(validation_timesteps)
+        original_args_min_timestep = args.min_timestep
+        original_args_max_timestep = args.max_timestep
+
         for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
             current_epoch.value = epoch + 1
@@ -1385,12 +1393,96 @@ class NetworkTrainer:
                     accelerator.unwrap_model(network).eval()
 
                     val_progress_bar = tqdm(
-                        range(validation_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="validation steps"
+                        range(validation_total_steps),
+                        smoothing=0,
+                        disable=not accelerator.is_local_main_process,
+                        desc="validation steps",
                     )
 
+                    val_ts_step = 0
                     for val_step, batch in enumerate(val_dataloader):
                         if val_step >= validation_steps:
                             break
 
+                        for timestep in validation_timesteps:
+                            # temporary, for batch processing
+                            self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+
+                            args.min_timestep = args.max_timestep = timestep  # dirty hack to change timestep
+
+                            loss = self.process_batch(
+                                batch,
+                                text_encoders,
+                                unet,
+                                network,
+                                vae,
+                                noise_scheduler,
+                                vae_dtype,
+                                weight_dtype,
+                                accelerator,
+                                args,
+                                text_encoding_strategy,
+                                tokenize_strategy,
+                                is_train=False,
+                                train_text_encoder=train_text_encoder,  # this is needed for validation because Text Encoders must be called if train_text_encoder is True
+                                train_unet=train_unet,
+                            )
+
+                            current_loss = loss.detach().item()
+                            val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
+                            val_progress_bar.update(1)
+                            val_progress_bar.set_postfix(
+                                {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
+                            )
+
+                            if is_tracking:
+                                logs = {
+                                    "loss/validation/step_current": current_loss,
+                                    "val_step": (epoch * validation_total_steps) + val_ts_step,
+                                }
+                                accelerator.log(logs, step=global_step)
+
+                            val_ts_step += 1
+
+                    if is_tracking:
+                        loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
+                        logs = {
+                            "loss/validation/step_average": val_step_loss_recorder.moving_average,
+                            "loss/validation/step_divergence": loss_validation_divergence,
+                        }
+                        accelerator.log(logs, step=global_step)
+
+                    args.min_timestep = original_args_min_timestep
+                    args.max_timestep = original_args_max_timestep
+                    optimizer_train_fn()
+                    accelerator.unwrap_model(network).train()
+
+                if global_step >= args.max_train_steps:
+                    break
+
+            # EPOCH VALIDATION
+            should_validate_epoch = (
+                (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
+            )
+
+            if should_validate_epoch and len(val_dataloader) > 0:
+                optimizer_eval_fn()
+                accelerator.unwrap_model(network).eval()
+
+                val_progress_bar = tqdm(
+                    range(validation_total_steps),
+                    smoothing=0,
+                    disable=not accelerator.is_local_main_process,
+                    desc="epoch validation steps",
+                )
+
+                val_ts_step = 0
+                for val_step, batch in enumerate(val_dataloader):
+                    if val_step >= validation_steps:
+                        break
+
+                    for timestep in validation_timesteps:
+                        args.min_timestep = args.max_timestep = timestep
+
                         # temporary, for batch processing
                         self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
 
@@ -1408,89 +1500,26 @@ class NetworkTrainer:
                         loss = self.process_batch(
                             batch,
                             text_encoders,
                             unet,
                             network,
                             vae,
                             noise_scheduler,
                             vae_dtype,
                             weight_dtype,
                             accelerator,
                             args,
                             text_encoding_strategy,
                             tokenize_strategy,
                             is_train=False,
-                            train_text_encoder=train_text_encoder,  # this is needed for validation because Text Encoders must be called if train_text_encoder is True
+                            train_text_encoder=train_text_encoder,
                             train_unet=train_unet,
                         )
 
                         current_loss = loss.detach().item()
-                        val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
+                        val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
                         val_progress_bar.update(1)
-                        val_progress_bar.set_postfix({"val_avg_loss": val_step_loss_recorder.moving_average})
+                        val_progress_bar.set_postfix(
+                            {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
+                        )
 
                         if is_tracking:
                             logs = {
-                                "loss/validation/step_current": current_loss,
-                                "val_step": (epoch * validation_steps) + val_step,
+                                "loss/validation/epoch_current": current_loss,
+                                "epoch": epoch + 1,
+                                "val_step": (epoch * validation_total_steps) + val_ts_step,
                             }
                             accelerator.log(logs, step=global_step)
 
-                    if is_tracking:
-                        loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
-                        logs = {
-                            "loss/validation/step_average": val_step_loss_recorder.moving_average,
-                            "loss/validation/step_divergence": loss_validation_divergence,
-                        }
-                        accelerator.log(logs, step=global_step)
-
-                    optimizer_train_fn()
-                    accelerator.unwrap_model(network).train()
-
-                if global_step >= args.max_train_steps:
-                    break
-
-            # EPOCH VALIDATION
-            should_validate_epoch = (
-                (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
-            )
-
-            if should_validate_epoch and len(val_dataloader) > 0:
-                optimizer_eval_fn()
-                accelerator.unwrap_model(network).eval()
-
-                val_progress_bar = tqdm(
-                    range(validation_steps),
-                    smoothing=0,
-                    disable=not accelerator.is_local_main_process,
-                    desc="epoch validation steps",
-                )
-
-                for val_step, batch in enumerate(val_dataloader):
-                    if val_step >= validation_steps:
-                        break
-
-                    # temporary, for batch processing
-                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
-
-                    loss = self.process_batch(
-                        batch,
-                        text_encoders,
-                        unet,
-                        network,
-                        vae,
-                        noise_scheduler,
-                        vae_dtype,
-                        weight_dtype,
-                        accelerator,
-                        args,
-                        text_encoding_strategy,
-                        tokenize_strategy,
-                        is_train=False,
-                        train_text_encoder=train_text_encoder,
-                        train_unet=train_unet,
-                    )
-
-                    current_loss = loss.detach().item()
-                    val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
-                    val_progress_bar.update(1)
-                    val_progress_bar.set_postfix({"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average})
-
-                    if is_tracking:
-                        logs = {
-                            "loss/validation/epoch_current": current_loss,
-                            "epoch": epoch + 1,
-                            "val_step": (epoch * validation_steps) + val_step,
-                        }
-                        accelerator.log(logs, step=global_step)
+                        val_ts_step += 1
 
                 if is_tracking:
                     avr_loss: float = val_epoch_loss_recorder.moving_average
@@ -1502,6 +1531,8 @@ class NetworkTrainer:
                     }
                     accelerator.log(logs, step=global_step)
 
+                args.min_timestep = original_args_min_timestep
+                args.max_timestep = original_args_max_timestep
                 optimizer_train_fn()
                 accelerator.unwrap_model(network).train()
 
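For reviewers, a minimal self-contained sketch of the timestep selection this patch introduces. It assumes the default 1000-step DDPM schedule and no `--min_timestep`/`--max_timestep` overrides; `args`, `num_train_timesteps`, and the loop body below are illustrative stand-ins, not code from the patch.

```python
import argparse

import numpy as np

# Stand-ins so the sketch runs on its own; a real run takes `args` from
# argparse and the schedule length from the noise scheduler.
args = argparse.Namespace(min_timestep=None, max_timestep=None)
num_train_timesteps = 1000  # default DDPM schedule length

NUM_VALIDATION_TIMESTEPS = 4
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = num_train_timesteps if args.max_timestep is None else args.max_timestep

# Sample NUM + 2 evenly spaced points, then drop both endpoints so only
# interior timesteps remain.
validation_timesteps = np.linspace(min_timestep, max_timestep, NUM_VALIDATION_TIMESTEPS + 2, dtype=int)[1:-1]
print(validation_timesteps)  # -> [200 400 600 800]

# The "dirty hack" from the patch: save the user's timestep range, pin
# both bounds to a single value per validation pass, and restore the
# range before training resumes, so every batch is scored at the same
# fixed timesteps rather than at randomly drawn ones.
original_min, original_max = args.min_timestep, args.max_timestep
for timestep in validation_timesteps:
    args.min_timestep = args.max_timestep = int(timestep)
    # ... process_batch(..., is_train=False) would run here ...
args.min_timestep, args.max_timestep = original_min, original_max
```

Scoring each batch at the same fixed, interior timesteps appears to be the point of the pinning: it removes timestep sampling as a source of variance in the validation loss, which is what makes metrics such as `loss/validation/step_divergence` comparable across steps and epochs.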