Add Validation loss for LoRA training

This commit is contained in:
Hina Chen
2024-12-27 16:47:59 +08:00
parent e89653975d
commit 05bb9183fa
3 changed files with 257 additions and 6 deletions

View File

@@ -73,6 +73,8 @@ class BaseSubsetParams:
token_warmup_min: int = 1
token_warmup_step: float = 0
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
@dataclass
@@ -102,6 +104,8 @@ class BaseDatasetParams:
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
@dataclass
@@ -478,9 +482,27 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset_klass = FineTuningDataset
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)
val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split <= 0.0:
continue
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)
# print info
info = ""
for i, dataset in enumerate(datasets):
@@ -566,6 +588,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
logger.info(f"{info}")
if len(val_datasets) > 0:
info = ""
for i, dataset in enumerate(val_datasets):
info += dedent(
f"""\
[Validation Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)
if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
else:
info += "\n"
for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Validation Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
"""
),
" ",
)
logger.info(f"{info}")
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
@@ -574,7 +640,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset.make_buckets()
dataset.set_seed(seed)
return DatasetGroup(datasets)
for i, dataset in enumerate(val_datasets):
logger.info(f"[Validation Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
return (
DatasetGroup(datasets),
DatasetGroup(val_datasets) if val_datasets else None
)
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):

View File

@@ -145,6 +145,17 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
def split_train_val(paths: List[str], validation_split: float, validation_seed: Optional[int]) -> List[str]:
    """Return the validation subset of *paths*: the last ``round(len(paths) * validation_split)``
    entries after a shuffle.

    The shuffle is applied to a copy, so the caller's list is never mutated.
    When ``validation_seed`` is given, the global ``random`` state is saved and
    restored around the seeded shuffle, so the split is reproducible without
    perturbing unrelated training randomness.

    Args:
        paths: Image paths to split.
        validation_split: Fraction (0.0-1.0) of images reserved for validation.
        validation_seed: Seed for a deterministic shuffle; ``None`` shuffles
            with the current global random state (non-reproducible).

    Returns:
        The validation slice of the shuffled paths (empty when
        ``validation_split`` <= 0).
    """
    shuffled = list(paths)  # work on a copy: do not mutate the caller's list
    if validation_seed is not None:
        print(f"Using validation seed: {validation_seed}")
        prevstate = random.getstate()
        random.seed(validation_seed)
        random.shuffle(shuffled)
        random.setstate(prevstate)  # restore so later randomness is unaffected
    else:
        random.shuffle(shuffled)
    # the tail of the shuffled list is the validation set
    return shuffled[len(shuffled) - round(len(shuffled) * validation_split):]
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -397,6 +408,8 @@ class BaseSubset:
token_warmup_min: int,
token_warmup_step: Union[float, int],
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -424,6 +437,9 @@ class BaseSubset:
self.img_count = 0
self.validation_seed = validation_seed
self.validation_split = validation_split
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -453,6 +469,8 @@ class DreamBoothSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -478,6 +496,8 @@ class DreamBoothSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)
self.is_reg = is_reg
@@ -518,6 +538,8 @@ class FineTuningSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -543,6 +565,8 @@ class FineTuningSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)
self.metadata_file = metadata_file
@@ -579,6 +603,8 @@ class ControlNetSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -604,6 +630,8 @@ class ControlNetSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)
self.conditioning_data_dir = conditioning_data_dir
@@ -1799,6 +1827,9 @@ class DreamBoothDataset(BaseDataset):
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset: bool,
is_train: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -1808,6 +1839,9 @@ class DreamBoothDataset(BaseDataset):
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
self.latents_cache = None
self.is_train = is_train
self.validation_seed = validation_seed
self.validation_split = validation_split
self.enable_bucket = enable_bucket
if self.enable_bucket:
@@ -1992,6 +2026,9 @@ class DreamBoothDataset(BaseDataset):
)
continue
if self.is_train == False:
img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed)
if subset.is_reg:
num_reg_images += subset.num_repeats * len(img_paths)
else:
@@ -2009,7 +2046,11 @@ class DreamBoothDataset(BaseDataset):
subset.img_count = len(img_paths)
self.subsets.append(subset)
logger.info(f"{num_train_images} train images with repeating.")
if self.is_train:
logger.info(f"{num_train_images} train images with repeating.")
else:
logger.info(f"{num_train_images} validation images with repeating.")
self.num_train_images = num_train_images
logger.info(f"{num_reg_images} reg images.")
@@ -2050,6 +2091,9 @@ class FineTuningDataset(BaseDataset):
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset: bool,
is_train: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -2276,6 +2320,9 @@ class ControlNetDataset(BaseDataset):
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset: float,
is_train: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -2324,6 +2371,9 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale,
1.0,
debug_dataset,
is_train,
validation_seed,
validation_split,
)
# config_util等から参照される値をいれておく若干微妙なのでなんとかしたい
@@ -4887,7 +4937,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "RAdamScheduleFree".lower():
optimizer_class = sf.RAdamScheduleFree
logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}")

View File

@@ -9,6 +9,7 @@ import json
from multiprocessing import Value
from typing import Any, List
import toml
import itertools
from tqdm import tqdm
@@ -114,7 +115,7 @@ class NetworkTrainer:
)
if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
):
):
logs[f"lr/d*lr/group{i}"] = (
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
)
@@ -373,10 +374,11 @@ class NetworkTrainer:
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -398,6 +400,11 @@ class NetworkTrainer:
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if val_dataset_group is not None:
assert (
val_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group) # may change some args
# acceleratorを準備する
@@ -444,6 +451,8 @@ class NetworkTrainer:
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
if val_dataset_group is not None:
val_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -459,6 +468,8 @@ class NetworkTrainer:
if text_encoder_outputs_caching_strategy is not None:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
if val_dataset_group is not None:
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
# prepare network
net_kwargs = {}
@@ -567,6 +578,8 @@ class NetworkTrainer:
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
if val_dataset_group is not None:
val_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
@@ -580,6 +593,17 @@ class NetworkTrainer:
persistent_workers=args.persistent_data_loader_workers,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset_group if val_dataset_group is not None else [],
batch_size=1,
shuffle=False,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
cyclic_val_dataloader = itertools.cycle(val_dataloader)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
@@ -592,6 +616,10 @@ class NetworkTrainer:
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# Not for sure here.
# if val_dataset_group is not None:
# val_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -1064,7 +1092,11 @@ class NetworkTrainer:
)
loss_recorder = train_util.LossRecorder()
# val_loss_recorder = train_util.LossRecorder()
del train_dataset_group
if val_dataset_group is not None:
del val_dataset_group
# callback for step start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
@@ -1308,6 +1340,77 @@ class NetworkTrainer:
)
accelerator.log(logs, step=global_step)
if len(val_dataloader) > 0:
if ((args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps):
accelerator.print("\nValidating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"):
batch = next(cyclic_val_dataloader)
timesteps_list = [10, 350, 500, 650, 990]
val_loss = 0.0
for fixed_timesteps in timesteps_list:
with torch.set_grad_enabled(False), accelerator.autocast():
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu")
timesteps = timesteps.long().to(latents.device)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
with accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
unet,
noisy_latents.requires_grad_(False),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
val_loss += loss / len(timesteps_list)
total_loss += val_loss.detach().item()
current_val_loss = total_loss / validation_steps
# val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_val_loss)
if len(accelerator.trackers) > 0:
logs = {"loss/current_val_loss": current_val_loss}
accelerator.log(logs, step=global_step)
# avr_loss: float = val_loss_recorder.moving_average
# logs = {"loss/average_val_loss": avr_loss}
# accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
@@ -1496,6 +1599,30 @@ def setup_parser() -> argparse.ArgumentParser:
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ未指定時と同じ。initial_epochを上書きする",
)
parser.add_argument(
"--validation_seed",
type=int,
default=None,
help="Validation seed / 検証シード"
)
parser.add_argument(
"--validation_split",
type=float,
default=0.0,
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
)
parser.add_argument(
"--validation_every_n_step",
type=int,
default=None,
help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する"
)
parser.add_argument(
"--max_validation_steps",
type=int,
default=None,
help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する"
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")