Merge pull request #1903 from kohya-ss/val-loss-improvement

Val loss improvement
This commit is contained in:
Kohya S.
2025-02-26 21:15:14 +09:00
committed by GitHub
4 changed files with 302 additions and 238 deletions

View File

@@ -381,8 +381,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
@@ -395,44 +394,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
with accelerator.autocast():
# move flux lower to cpu, and then move flux upper to gpu
unet.to("cpu")
clean_memory_on_device(accelerator.device)
self.flux_upper.to(accelerator.device)
# upper model does not require grad
with torch.no_grad():
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# move flux upper back to cpu, and then move flux lower to gpu
self.flux_upper.to("cpu")
clean_memory_on_device(accelerator.device)
unet.to(accelerator.device)
# lower model requires grad
intermediate_img.requires_grad_(True)
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""
return model_pred
model_pred = call_dit(
@@ -551,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    """Hook called after each validation forward pass.

    Validation never runs a backward pass, so the block-swap bookkeeping that
    normally happens during backward must be done here instead.
    """
    if self.is_swapping_blocks:
        # prepare for next forward: because backward pass is not called, we need to prepare it here
        accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:

View File

@@ -13,17 +13,7 @@ import re
import shutil
import time
import typing
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Union
)
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import glob
import math
@@ -146,12 +136,13 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
def split_train_val(
paths: List[str],
paths: List[str],
sizes: List[Optional[Tuple[int, int]]],
is_training_dataset: bool,
validation_split: float,
validation_seed: int | None
is_training_dataset: bool,
validation_split: float,
validation_seed: int | None,
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
"""
Split the dataset into train and validation
@@ -1842,7 +1833,7 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset):
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
# The is_training_dataset defines the type of dataset, training or validation
# The is_training_dataset defines the type of dataset, training or validation
# if is_training_dataset is True -> training dataset
# if is_training_dataset is False -> validation dataset
def __init__(
@@ -1981,29 +1972,25 @@ class DreamBoothDataset(BaseDataset):
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
# We want to create a training and validation split. This should be improved in the future
# to allow a clearer distinction between training and validation. This can be seen as a
# to allow a clearer distinction between training and validation. This can be seen as a
# short-term solution to limit what is necessary to implement validation datasets
#
#
# We split the dataset for the subset based on if we are doing a validation split
# The self.is_training_dataset defines the type of dataset, training or validation
# The self.is_training_dataset defines the type of dataset, training or validation
# if self.is_training_dataset is True -> training dataset
# if self.is_training_dataset is False -> validation dataset
if self.validation_split > 0.0:
# For regularization images we do not want to split this dataset.
# For regularization images we do not want to split this dataset.
if subset.is_reg is True:
# Skip any validation dataset for regularization images
if self.is_training_dataset is False:
img_paths = []
sizes = []
# Otherwise the img_paths remain as original img_paths and no split
# Otherwise the img_paths remain as original img_paths and no split
# required for training images dataset of regularization images
else:
img_paths, sizes = split_train_val(
img_paths,
sizes,
self.is_training_dataset,
self.validation_split,
self.validation_seed
img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
@@ -2373,7 +2360,7 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
validation_seed: Optional[int],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -2431,9 +2418,9 @@ class ControlNetDataset(BaseDataset):
self.image_data = self.dreambooth_dataset_delegate.image_data
self.batch_size = batch_size
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split
self.validation_seed = validation_seed
self.validation_seed = validation_seed
# assert all conditioning data exists
missing_imgs = []
@@ -5944,12 +5931,17 @@ def save_sd_model_on_train_end_common(
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
    """Sample one diffusion timestep per batch element.

    Draws uniformly from [min_timestep, max_timestep) when the range is
    non-empty; when min_timestep == max_timestep the range would be invalid
    for torch.randint, so every element is set to max_timestep instead.

    Args:
        min_timestep: inclusive lower bound of the timestep range.
        max_timestep: exclusive upper bound (or the constant value when the range is empty).
        b_size: batch size, i.e. number of timesteps to sample.
        device: device the returned tensor is moved to.

    Returns:
        int64 tensor of shape (b_size,) on `device`.
    """
    # Note: a stale unconditional randint call from before the min==max guard
    # was removed — it crashed when min_timestep == max_timestep.
    if min_timestep < max_timestep:
        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
    else:
        timesteps = torch.full((b_size,), max_timestep, device="cpu")
    timesteps = timesteps.long().to(device)
    return timesteps
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
def get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents: torch.FloatTensor
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
@@ -6441,7 +6433,7 @@ def sample_image_inference(
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
"""
Initialize experiment trackers with tracker specific behaviors
"""
@@ -6458,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
)
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb
import wandb
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
# Define specific metrics to handle validation and epochs "steps"
wandb_tracker.define_metric("epoch", hidden=True)
wandb_tracker.define_metric("val_step", hidden=True)
wandb_tracker.define_metric("global_step", hidden=True)
# endregion

View File

@@ -450,14 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
# drop cached text encoder outputs
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
    """Drop some cached text-encoder outputs before the forward pass.

    In validation (is_train=False) the caller fixes the RNG seed, so the
    dropout applied here is deterministic across validation runs.
    """
    # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
    cached_outputs = batch.get("text_encoder_outputs_list", None)
    if cached_outputs is None:
        return
    encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
    batch["text_encoder_outputs_list"] = encoding_strategy.drop_cached_text_encoder_outputs(*cached_outputs)
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    """Hook called after each validation forward pass.

    Validation never runs a backward pass, so the block-swap bookkeeping that
    normally happens during backward must be done here instead.
    """
    if self.is_swapping_blocks:
        # prepare for next forward: because backward pass is not called, we need to prepare it here
        accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:

View File

@@ -9,6 +9,7 @@ import random
import time
import json
from multiprocessing import Value
import numpy as np
import toml
from tqdm import tqdm
@@ -100,9 +101,7 @@ class NetworkTrainer:
if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
): # tracking d*lr value of unet.
logs["lr/d*lr"] = (
optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
)
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
else:
idx = 0
if not args.network_train_unet_only:
@@ -115,16 +114,56 @@ class NetworkTrainer:
logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
):
logs[f"lr/d*lr/group{i}"] = (
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
)
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
return logs
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
    """Log per-training-step metrics; the tracker step equals global_step."""
    self.accelerator_logging(accelerator, logs, step_value=global_step, global_step=global_step, epoch=epoch)
def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
    """Log per-epoch metrics; the tracker step equals the epoch number."""
    self.accelerator_logging(accelerator, logs, step_value=epoch, global_step=global_step, epoch=epoch)
def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int):
    """Log validation metrics; the tracker step is offset by the validation step."""
    self.accelerator_logging(
        accelerator, logs, step_value=global_step + val_step, global_step=global_step, epoch=epoch, val_step=val_step
    )
def accelerator_logging(
    self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None
):
    """
    Dispatch `logs` to every configured tracker with tracker-specific step handling.

    step_value is the step passed to tensorboard (and to any unknown tracker);
    global_step / epoch / val_step are merged into the payload for wandb,
    which relies on its defined metrics rather than an explicit step argument.
    """
    # Partition trackers by backend so each gets the step semantics it expects.
    tensorboard_tracker = None
    wandb_tracker = None
    other_trackers = []
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            tensorboard_tracker = accelerator.get_tracker("tensorboard")
        elif tracker.name == "wandb":
            wandb_tracker = accelerator.get_tracker("wandb")
        else:
            other_trackers.append(accelerator.get_tracker(tracker.name))
    # tensorboard logs before the wandb-specific keys are injected below
    if tensorboard_tracker is not None:
        tensorboard_tracker.log(logs, step=step_value)
    if wandb_tracker is not None:
        # NOTE(review): these keys are added to the caller's `logs` dict in place,
        # so the `other_trackers` loop below also logs them — confirm this is intended
        logs["global_step"] = global_step
        logs["epoch"] = epoch
        if val_step is not None:
            logs["val_step"] = val_step
        wandb_tracker.log(logs)
    for tracker in other_trackers:
        tracker.log(logs, step=step_value)
def assert_extra_args(
    self,
    args,
    train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
    val_dataset_group: Optional[train_util.DatasetGroup],
):
    """Validate trainer-specific constraints on the dataset groups.

    Both the training and (if present) validation dataset groups must use
    bucket resolutions that are multiples of 64.
    """
    groups = [train_dataset_group]
    if val_dataset_group is not None:
        groups.append(val_dataset_group)
    for group in groups:
        group.verify_bucket_reso_steps(64)
@@ -219,7 +258,7 @@ class NetworkTrainer:
network,
weight_dtype,
train_unet,
is_train=True
is_train=True,
):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -309,28 +348,31 @@ class NetworkTrainer:
) -> torch.nn.Module:
return accelerator.prepare(unet)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
pass
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass
# endregion
def process_batch(
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy: strategy_base.TextEncodingStrategy,
tokenize_strategy: strategy_base.TokenizeStrategy,
is_train=True,
train_text_encoder=True,
train_unet=True
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy: strategy_base.TextEncodingStrategy,
tokenize_strategy: strategy_base.TokenizeStrategy,
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
"""
Process a batch for the network
@@ -397,7 +439,7 @@ class NetworkTrainer:
network,
weight_dtype,
train_unet,
is_train=is_train
is_train=is_train,
)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -484,7 +526,7 @@ class NetworkTrainer:
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -701,7 +743,7 @@ class NetworkTrainer:
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset_group if val_dataset_group is not None else [],
shuffle=False,
@@ -900,7 +942,9 @@ class NetworkTrainer:
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
accelerator.print(
f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}"
)
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
@@ -968,11 +1012,11 @@ class NetworkTrainer:
"ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1243,12 +1287,6 @@ class NetworkTrainer:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
validation_steps = (
min(args.max_validation_steps, len(val_dataloader))
if args.max_validation_steps is not None
else len(val_dataloader)
)
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
@@ -1271,13 +1309,53 @@ class NetworkTrainer:
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)
validation_steps = (
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
)
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
validation_total_steps = validation_steps * len(validation_timesteps)
original_args_min_timestep = args.min_timestep
original_args_max_timestep = args.max_timestep
def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
    """Save the current RNG states and reseed for deterministic validation.

    Captures the CPU torch state, the accelerator-device state (CUDA/XPU/MPS;
    None for other devices), and the Python `random` state, then seeds both
    torch and `random` with `seed`.

    Args:
        seed: seed applied to torch (all devices) and Python `random`.

    Returns:
        (cpu_rng_state, gpu_rng_state_or_None, python_rng_state) for
        later use with `restore_rng_state`.
    """
    cpu_rng_state = torch.get_rng_state()
    if accelerator.device.type == "cuda":
        gpu_rng_state = torch.cuda.get_rng_state()
    elif accelerator.device.type == "xpu":
        gpu_rng_state = torch.xpu.get_rng_state()
    elif accelerator.device.type == "mps":
        # fix: use the MPS generator — torch.cuda is unavailable on Apple silicon,
        # so the previous torch.cuda.get_rng_state() call would fail here
        gpu_rng_state = torch.mps.get_rng_state()
    else:
        gpu_rng_state = None
    python_rng_state = random.getstate()

    torch.manual_seed(seed)
    random.seed(seed)

    return (cpu_rng_state, gpu_rng_state, python_rng_state)
def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
    """Restore the RNG states previously captured by `switch_rng_state`.

    Args:
        rng_states: (cpu_rng_state, gpu_rng_state_or_None, python_rng_state)
            as returned by `switch_rng_state`.
    """
    cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
    torch.set_rng_state(cpu_rng_state)
    if gpu_rng_state is not None:
        if accelerator.device.type == "cuda":
            torch.cuda.set_rng_state(gpu_rng_state)
        elif accelerator.device.type == "xpu":
            torch.xpu.set_rng_state(gpu_rng_state)
        elif accelerator.device.type == "mps":
            # fix: restore via the MPS generator — torch.cuda is unavailable on
            # Apple silicon, so the previous torch.cuda.set_rng_state() would fail
            torch.mps.set_rng_state(gpu_rng_state)
    random.setstate(python_rng_state)
for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here
# TRAINING
skipped_dataloader = None
@@ -1294,25 +1372,25 @@ class NetworkTrainer:
with accelerator.accumulate(training_model):
on_step_start_for_network(text_encoder, unet)
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
)
accelerator.backward(loss)
@@ -1369,148 +1447,167 @@ class NetworkTrainer:
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking:
logs = self.generate_step_logs(
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
)
accelerator.log(logs, step=global_step)
self.step_logging(accelerator, logs, global_step, epoch + 1)
# VALIDATION PER STEP
should_validate_step = (
args.validate_every_n_steps is not None
and global_step != 0 # Skip first step
and global_step % args.validate_every_n_steps == 0
)
# VALIDATION PER STEP: global_step is already incremented
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps"
range(validation_total_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps",
)
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
for timestep in validation_timesteps:
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
train_unet=train_unet,
)
if is_tracking:
logs = {
"loss/validation/step_current": current_loss,
"val_step": (epoch * validation_steps) + val_step,
}
accelerator.log(logs, step=global_step)
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
)
# if is_tracking:
# logs = {f"loss/validation/step_current_{timestep}": current_loss}
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_timesteps_step += 1
if is_tracking:
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
logs = {
"loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence,
"loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence,
}
accelerator.log(logs, step=global_step)
self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
accelerator.unwrap_model(network).train()
progress_bar.unpause()
if global_step >= args.max_train_steps:
break
# EPOCH VALIDATION
should_validate_epoch = (
(epoch + 1) % args.validate_every_n_epochs == 0
if args.validate_every_n_epochs is not None
else True
(epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
)
if should_validate_epoch and len(val_dataloader) > 0:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="epoch validation steps"
range(validation_total_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="epoch validation steps",
)
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
for timestep in validation_timesteps:
args.min_timestep = args.max_timestep = timestep
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
)
if is_tracking:
logs = {
"loss/validation/epoch_current": current_loss,
"epoch": epoch + 1,
"val_step": (epoch * validation_steps) + val_step
}
accelerator.log(logs, step=global_step)
current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
)
# if is_tracking:
# logs = {f"loss/validation/epoch_current_{timestep}": current_loss}
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_timesteps_step += 1
if is_tracking:
avr_loss: float = val_epoch_loss_recorder.moving_average
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
logs = {
"loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence,
"epoch": epoch + 1
"loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence,
}
accelerator.log(logs, step=global_step)
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
accelerator.unwrap_model(network).train()
progress_bar.unpause()
# END OF EPOCH
if is_tracking:
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
accelerator.log(logs, step=global_step)
logs = {"loss/epoch_average": loss_recorder.moving_average}
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
@@ -1696,31 +1793,31 @@ def setup_parser() -> argparse.ArgumentParser:
"--validation_seed",
type=int,
default=None,
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する"
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する",
)
parser.add_argument(
"--validation_split",
type=float,
default=0.0,
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合",
)
parser.add_argument(
"--validate_every_n_steps",
type=int,
default=None,
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます"
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます",
)
parser.add_argument(
"--validate_every_n_epochs",
type=int,
default=None,
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます"
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます",
)
parser.add_argument(
"--max_validation_steps",
type=int,
default=None,
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
)
return parser