Merge pull request #1903 from kohya-ss/val-loss-improvement

Val loss improvement
This commit is contained in:
Kohya S.
2025-02-26 21:15:14 +09:00
committed by GitHub
4 changed files with 302 additions and 238 deletions

View File

@@ -381,8 +381,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t5_attn_mask = None t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode: # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
# normal forward
with torch.set_grad_enabled(is_train), accelerator.autocast(): with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet( model_pred = unet(
@@ -395,44 +394,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance=guidance_vec, guidance=guidance_vec,
txt_attention_mask=t5_attn_mask, txt_attention_mask=t5_attn_mask,
) )
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
with accelerator.autocast():
# move flux lower to cpu, and then move flux upper to gpu
unet.to("cpu")
clean_memory_on_device(accelerator.device)
self.flux_upper.to(accelerator.device)
# upper model does not require grad
with torch.no_grad():
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# move flux upper back to cpu, and then move flux lower to gpu
self.flux_upper.to("cpu")
clean_memory_on_device(accelerator.device)
unet.to(accelerator.device)
# lower model requires grad
intermediate_img.requires_grad_(True)
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""
return model_pred return model_pred
model_pred = call_dit( model_pred = call_dit(
@@ -551,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8 text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype) prepare_fp8(text_encoder, weight_dtype)
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator( def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module: ) -> torch.nn.Module:

View File

@@ -13,17 +13,7 @@ import re
import shutil import shutil
import time import time
import typing import typing
from typing import ( from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Union
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import glob import glob
import math import math
@@ -146,12 +136,13 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
def split_train_val( def split_train_val(
paths: List[str], paths: List[str],
sizes: List[Optional[Tuple[int, int]]], sizes: List[Optional[Tuple[int, int]]],
is_training_dataset: bool, is_training_dataset: bool,
validation_split: float, validation_split: float,
validation_seed: int | None validation_seed: int | None,
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
""" """
Split the dataset into train and validation Split the dataset into train and validation
@@ -1842,7 +1833,7 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset): class DreamBoothDataset(BaseDataset):
IMAGE_INFO_CACHE_FILE = "metadata_cache.json" IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
# The is_training_dataset defines the type of dataset, training or validation # The is_training_dataset defines the type of dataset, training or validation
# if is_training_dataset is True -> training dataset # if is_training_dataset is True -> training dataset
# if is_training_dataset is False -> validation dataset # if is_training_dataset is False -> validation dataset
def __init__( def __init__(
@@ -1981,29 +1972,25 @@ class DreamBoothDataset(BaseDataset):
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
# We want to create a training and validation split. This should be improved in the future # We want to create a training and validation split. This should be improved in the future
# to allow a clearer distinction between training and validation. This can be seen as a # to allow a clearer distinction between training and validation. This can be seen as a
# short-term solution to limit what is necessary to implement validation datasets # short-term solution to limit what is necessary to implement validation datasets
# #
# We split the dataset for the subset based on if we are doing a validation split # We split the dataset for the subset based on if we are doing a validation split
# The self.is_training_dataset defines the type of dataset, training or validation # The self.is_training_dataset defines the type of dataset, training or validation
# if self.is_training_dataset is True -> training dataset # if self.is_training_dataset is True -> training dataset
# if self.is_training_dataset is False -> validation dataset # if self.is_training_dataset is False -> validation dataset
if self.validation_split > 0.0: if self.validation_split > 0.0:
# For regularization images we do not want to split this dataset. # For regularization images we do not want to split this dataset.
if subset.is_reg is True: if subset.is_reg is True:
# Skip any validation dataset for regularization images # Skip any validation dataset for regularization images
if self.is_training_dataset is False: if self.is_training_dataset is False:
img_paths = [] img_paths = []
sizes = [] sizes = []
# Otherwise the img_paths remain as original img_paths and no split # Otherwise the img_paths remain as original img_paths and no split
# required for training images dataset of regularization images # required for training images dataset of regularization images
else: else:
img_paths, sizes = split_train_val( img_paths, sizes = split_train_val(
img_paths, img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
sizes,
self.is_training_dataset,
self.validation_split,
self.validation_seed
) )
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
@@ -2373,7 +2360,7 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool, bucket_no_upscale: bool,
debug_dataset: bool, debug_dataset: bool,
validation_split: float, validation_split: float,
validation_seed: Optional[int], validation_seed: Optional[int],
) -> None: ) -> None:
super().__init__(resolution, network_multiplier, debug_dataset) super().__init__(resolution, network_multiplier, debug_dataset)
@@ -2431,9 +2418,9 @@ class ControlNetDataset(BaseDataset):
self.image_data = self.dreambooth_dataset_delegate.image_data self.image_data = self.dreambooth_dataset_delegate.image_data
self.batch_size = batch_size self.batch_size = batch_size
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split self.validation_split = validation_split
self.validation_seed = validation_seed self.validation_seed = validation_seed
# assert all conditioning data exists # assert all conditioning data exists
missing_imgs = [] missing_imgs = []
@@ -5944,12 +5931,17 @@ def save_sd_model_on_train_end_common(
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") if min_timestep < max_timestep:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
else:
timesteps = torch.full((b_size,), max_timestep, device="cpu")
timesteps = timesteps.long().to(device) timesteps = timesteps.long().to(device)
return timesteps return timesteps
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: def get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents: torch.FloatTensor
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device) noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset: if args.noise_offset:
@@ -6441,7 +6433,7 @@ def sample_image_inference(
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
""" """
Initialize experiment trackers with tracker specific behaviors Initialize experiment trackers with tracker specific behaviors
""" """
@@ -6458,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
) )
if "wandb" in [tracker.name for tracker in accelerator.trackers]: if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb import wandb
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
# Define specific metrics to handle validation and epochs "steps" # Define specific metrics to handle validation and epochs "steps"
wandb_tracker.define_metric("epoch", hidden=True) wandb_tracker.define_metric("epoch", hidden=True)
wandb_tracker.define_metric("val_step", hidden=True) wandb_tracker.define_metric("val_step", hidden=True)
wandb_tracker.define_metric("global_step", hidden=True)
# endregion # endregion

View File

@@ -450,14 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8 text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype) prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
# drop cached text encoder outputs # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None: if text_encoder_outputs_list is not None:
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list batch["text_encoder_outputs_list"] = text_encoder_outputs_list
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator( def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module: ) -> torch.nn.Module:

View File

@@ -9,6 +9,7 @@ import random
import time import time
import json import json
from multiprocessing import Value from multiprocessing import Value
import numpy as np
import toml import toml
from tqdm import tqdm from tqdm import tqdm
@@ -100,9 +101,7 @@ class NetworkTrainer:
if ( if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
): # tracking d*lr value of unet. ): # tracking d*lr value of unet.
logs["lr/d*lr"] = ( logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
)
else: else:
idx = 0 idx = 0
if not args.network_train_unet_only: if not args.network_train_unet_only:
@@ -115,16 +114,56 @@ class NetworkTrainer:
logs[f"lr/d*lr/group{i}"] = ( logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
) )
if ( if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
):
logs[f"lr/d*lr/group{i}"] = (
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
)
return logs return logs
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
self.accelerator_logging(accelerator, logs, global_step, global_step, epoch)
def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
self.accelerator_logging(accelerator, logs, epoch, global_step, epoch)
def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int):
self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step)
def accelerator_logging(
self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None
):
"""
step_value is for tensorboard, other values are for wandb
"""
tensorboard_tracker = None
wandb_tracker = None
other_trackers = []
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
tensorboard_tracker = accelerator.get_tracker("tensorboard")
elif tracker.name == "wandb":
wandb_tracker = accelerator.get_tracker("wandb")
else:
other_trackers.append(accelerator.get_tracker(tracker.name))
if tensorboard_tracker is not None:
tensorboard_tracker.log(logs, step=step_value)
if wandb_tracker is not None:
logs["global_step"] = global_step
logs["epoch"] = epoch
if val_step is not None:
logs["val_step"] = val_step
wandb_tracker.log(logs)
for tracker in other_trackers:
tracker.log(logs, step=step_value)
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
train_dataset_group.verify_bucket_reso_steps(64) train_dataset_group.verify_bucket_reso_steps(64)
if val_dataset_group is not None: if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64) val_dataset_group.verify_bucket_reso_steps(64)
@@ -219,7 +258,7 @@ class NetworkTrainer:
network, network,
weight_dtype, weight_dtype,
train_unet, train_unet,
is_train=True is_train=True,
): ):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
@@ -309,28 +348,31 @@ class NetworkTrainer:
) -> torch.nn.Module: ) -> torch.nn.Module:
return accelerator.prepare(unet) return accelerator.prepare(unet)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
pass
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass pass
# endregion # endregion
def process_batch( def process_batch(
self, self,
batch, batch,
text_encoders, text_encoders,
unet, unet,
network, network,
vae, vae,
noise_scheduler, noise_scheduler,
vae_dtype, vae_dtype,
weight_dtype, weight_dtype,
accelerator, accelerator,
args, args,
text_encoding_strategy: strategy_base.TextEncodingStrategy, text_encoding_strategy: strategy_base.TextEncodingStrategy,
tokenize_strategy: strategy_base.TokenizeStrategy, tokenize_strategy: strategy_base.TokenizeStrategy,
is_train=True, is_train=True,
train_text_encoder=True, train_text_encoder=True,
train_unet=True train_unet=True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Process a batch for the network Process a batch for the network
@@ -397,7 +439,7 @@ class NetworkTrainer:
network, network,
weight_dtype, weight_dtype,
train_unet, train_unet,
is_train=is_train is_train=is_train,
) )
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -484,7 +526,7 @@ class NetworkTrainer:
else: else:
# use arbitrary dataset class # use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args) train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None # placeholder until validation dataset supported for arbitrary val_dataset_group = None # placeholder until validation dataset supported for arbitrary
current_epoch = Value("i", 0) current_epoch = Value("i", 0)
current_step = Value("i", 0) current_step = Value("i", 0)
@@ -701,7 +743,7 @@ class NetworkTrainer:
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
val_dataloader = torch.utils.data.DataLoader( val_dataloader = torch.utils.data.DataLoader(
val_dataset_group if val_dataset_group is not None else [], val_dataset_group if val_dataset_group is not None else [],
shuffle=False, shuffle=False,
@@ -900,7 +942,9 @@ class NetworkTrainer:
accelerator.print("running training / 学習開始") accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") accelerator.print(
f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}"
)
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
@@ -968,11 +1012,11 @@ class NetworkTrainer:
"ss_huber_c": args.huber_c, "ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base), "ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet), "ss_fp8_base_unet": bool(args.fp8_base_unet),
"ss_validation_seed": args.validation_seed, "ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split, "ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps, "ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs, "ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps, "ss_validate_every_n_steps": args.validate_every_n_steps,
} }
self.update_metadata(metadata, args) # architecture specific metadata self.update_metadata(metadata, args) # architecture specific metadata
@@ -1243,12 +1287,6 @@ class NetworkTrainer:
# log empty object to commit the sample images to wandb # log empty object to commit the sample images to wandb
accelerator.log({}, step=0) accelerator.log({}, step=0)
validation_steps = (
min(args.max_validation_steps, len(val_dataloader))
if args.max_validation_steps is not None
else len(val_dataloader)
)
# training loop # training loop
if initial_step > 0: # only if skip_until_initial_step is specified if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs for skip_epoch in range(epoch_to_start): # skip epochs
@@ -1271,13 +1309,53 @@ class NetworkTrainer:
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
) )
validation_steps = (
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
)
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
validation_total_steps = validation_steps * len(validation_timesteps)
original_args_min_timestep = args.min_timestep
original_args_max_timestep = args.max_timestep
def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
cpu_rng_state = torch.get_rng_state()
if accelerator.device.type == "cuda":
gpu_rng_state = torch.cuda.get_rng_state()
elif accelerator.device.type == "xpu":
gpu_rng_state = torch.xpu.get_rng_state()
elif accelerator.device.type == "mps":
gpu_rng_state = torch.cuda.get_rng_state()
else:
gpu_rng_state = None
python_rng_state = random.getstate()
torch.manual_seed(seed)
random.seed(seed)
return (cpu_rng_state, gpu_rng_state, python_rng_state)
def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
torch.set_rng_state(cpu_rng_state)
if gpu_rng_state is not None:
if accelerator.device.type == "cuda":
torch.cuda.set_rng_state(gpu_rng_state)
elif accelerator.device.type == "xpu":
torch.xpu.set_rng_state(gpu_rng_state)
elif accelerator.device.type == "mps":
torch.cuda.set_rng_state(gpu_rng_state)
random.setstate(python_rng_state)
for epoch in range(epoch_to_start, num_train_epochs): for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
current_epoch.value = epoch + 1 current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1) metadata["ss_epoch"] = str(epoch + 1)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here
# TRAINING # TRAINING
skipped_dataloader = None skipped_dataloader = None
@@ -1294,25 +1372,25 @@ class NetworkTrainer:
with accelerator.accumulate(training_model): with accelerator.accumulate(training_model):
on_step_start_for_network(text_encoder, unet) on_step_start_for_network(text_encoder, unet)
# temporary, for batch processing # preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch( loss = self.process_batch(
batch, batch,
text_encoders, text_encoders,
unet, unet,
network, network,
vae, vae,
noise_scheduler, noise_scheduler,
vae_dtype, vae_dtype,
weight_dtype, weight_dtype,
accelerator, accelerator,
args, args,
text_encoding_strategy, text_encoding_strategy,
tokenize_strategy, tokenize_strategy,
is_train=True, is_train=True,
train_text_encoder=train_text_encoder, train_text_encoder=train_text_encoder,
train_unet=train_unet train_unet=train_unet,
) )
accelerator.backward(loss) accelerator.backward(loss)
@@ -1369,148 +1447,167 @@ class NetworkTrainer:
if args.scale_weight_norms: if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs}) progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking: if is_tracking:
logs = self.generate_step_logs( logs = self.generate_step_logs(
args, args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm
) )
accelerator.log(logs, step=global_step) self.step_logging(accelerator, logs, global_step, epoch + 1)
# VALIDATION PER STEP # VALIDATION PER STEP: global_step is already incremented
should_validate_step = ( # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
args.validate_every_n_steps is not None should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
and global_step != 0 # Skip first step
and global_step % args.validate_every_n_steps == 0
)
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm( val_progress_bar = tqdm(
range(validation_steps), smoothing=0, range(validation_total_steps),
disable=not accelerator.is_local_main_process, smoothing=0,
desc="validation steps" disable=not accelerator.is_local_main_process,
desc="validation steps",
) )
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader): for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps: if val_step >= validation_steps:
break break
# temporary, for batch processing for timestep in validation_timesteps:
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch( args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)
current_loss = loss.detach().item() loss = self.process_batch(
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) batch,
val_progress_bar.update(1) text_encoders,
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
train_unet=train_unet,
)
if is_tracking: current_loss = loss.detach().item()
logs = { val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
"loss/validation/step_current": current_loss, val_progress_bar.update(1)
"val_step": (epoch * validation_steps) + val_step, val_progress_bar.set_postfix(
} {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
accelerator.log(logs, step=global_step) )
# if is_tracking:
# logs = {f"loss/validation/step_current_{timestep}": current_loss}
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_timesteps_step += 1
if is_tracking: if is_tracking:
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
logs = { logs = {
"loss/validation/step_average": val_step_loss_recorder.moving_average, "loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence, "loss/validation/step_divergence": loss_validation_divergence,
} }
accelerator.log(logs, step=global_step) self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
accelerator.unwrap_model(network).train()
progress_bar.unpause()
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
# EPOCH VALIDATION # EPOCH VALIDATION
should_validate_epoch = ( should_validate_epoch = (
(epoch + 1) % args.validate_every_n_epochs == 0 (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
if args.validate_every_n_epochs is not None
else True
) )
if should_validate_epoch and len(val_dataloader) > 0: if should_validate_epoch and len(val_dataloader) > 0:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm( val_progress_bar = tqdm(
range(validation_steps), smoothing=0, range(validation_total_steps),
disable=not accelerator.is_local_main_process, smoothing=0,
desc="epoch validation steps" disable=not accelerator.is_local_main_process,
desc="epoch validation steps",
) )
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader): for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps: if val_step >= validation_steps:
break break
# temporary, for batch processing for timestep in validation_timesteps:
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) args.min_timestep = args.max_timestep = timestep
loss = self.process_batch( # temporary, for batch processing
batch, self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)
current_loss = loss.detach().item() loss = self.process_batch(
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) batch,
val_progress_bar.update(1) text_encoders,
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
)
if is_tracking: current_loss = loss.detach().item()
logs = { val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
"loss/validation/epoch_current": current_loss, val_progress_bar.update(1)
"epoch": epoch + 1, val_progress_bar.set_postfix(
"val_step": (epoch * validation_steps) + val_step {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
} )
accelerator.log(logs, step=global_step)
# if is_tracking:
# logs = {f"loss/validation/epoch_current_{timestep}": current_loss}
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_timesteps_step += 1
if is_tracking: if is_tracking:
avr_loss: float = val_epoch_loss_recorder.moving_average avr_loss: float = val_epoch_loss_recorder.moving_average
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
logs = { logs = {
"loss/validation/epoch_average": avr_loss, "loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence, "loss/validation/epoch_divergence": loss_validation_divergence,
"epoch": epoch + 1
} }
accelerator.log(logs, step=global_step) self.epoch_logging(accelerator, logs, global_step, epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
accelerator.unwrap_model(network).train()
progress_bar.unpause()
# END OF EPOCH # END OF EPOCH
if is_tracking: if is_tracking:
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} logs = {"loss/epoch_average": loss_recorder.moving_average}
accelerator.log(logs, step=global_step) self.epoch_logging(accelerator, logs, global_step, epoch + 1)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存 # 指定エポックごとにモデルを保存
@@ -1696,31 +1793,31 @@ def setup_parser() -> argparse.ArgumentParser:
"--validation_seed", "--validation_seed",
type=int, type=int,
default=None, default=None,
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する",
) )
parser.add_argument( parser.add_argument(
"--validation_split", "--validation_split",
type=float, type=float,
default=0.0, default=0.0,
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合",
) )
parser.add_argument( parser.add_argument(
"--validate_every_n_steps", "--validate_every_n_steps",
type=int, type=int,
default=None, default=None,
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます",
) )
parser.add_argument( parser.add_argument(
"--validate_every_n_epochs", "--validate_every_n_epochs",
type=int, type=int,
default=None, default=None,
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます",
) )
parser.add_argument( parser.add_argument(
"--max_validation_steps", "--max_validation_steps",
type=int, type=int,
default=None, default=None,
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
) )
return parser return parser