mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix regularization images with validation
Adding metadata recording for validation arguments Add comments about the validation split for clarity of intention
This commit is contained in:
@@ -146,7 +146,12 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
|||||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||||
|
|
||||||
def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]:
|
def split_train_val(
|
||||||
|
paths: List[str],
|
||||||
|
is_training_dataset: bool,
|
||||||
|
validation_split: float,
|
||||||
|
validation_seed: int | None
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Split the dataset into train and validation
|
Split the dataset into train and validation
|
||||||
|
|
||||||
@@ -1830,6 +1835,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
class DreamBoothDataset(BaseDataset):
|
class DreamBoothDataset(BaseDataset):
|
||||||
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
|
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
|
||||||
|
|
||||||
|
# The is_training_dataset defines the type of dataset, training or validation
|
||||||
|
# if is_training_dataset is True -> training dataset
|
||||||
|
# if is_training_dataset is False -> validation dataset
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
subsets: Sequence[DreamBoothSubset],
|
subsets: Sequence[DreamBoothSubset],
|
||||||
@@ -1965,8 +1973,29 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
size_set_count += 1
|
size_set_count += 1
|
||||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||||
|
|
||||||
|
# We want to create a training and validation split. This should be improved in the future
|
||||||
|
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||||
|
# short-term solution to limit what is necessary to implement validation datasets
|
||||||
|
#
|
||||||
|
# We split the dataset for the subset based on if we are doing a validation split
|
||||||
|
# The self.is_training_dataset defines the type of dataset, training or validation
|
||||||
|
# if self.is_training_dataset is True -> training dataset
|
||||||
|
# if self.is_training_dataset is False -> validation dataset
|
||||||
if self.validation_split > 0.0:
|
if self.validation_split > 0.0:
|
||||||
img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed)
|
# For regularization images we do not want to split this dataset.
|
||||||
|
if subset.is_reg is True:
|
||||||
|
# Skip any validation dataset for regularization images
|
||||||
|
if self.is_training_dataset is False:
|
||||||
|
img_paths = []
|
||||||
|
# Otherwise the img_paths remain as original img_paths and no split
|
||||||
|
# required for training images dataset of regularization images
|
||||||
|
else:
|
||||||
|
img_paths = split_train_val(
|
||||||
|
img_paths,
|
||||||
|
self.is_training_dataset,
|
||||||
|
self.validation_split,
|
||||||
|
self.validation_seed
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||||
|
|
||||||
|
|||||||
@@ -898,6 +898,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
accelerator.print("running training / 学習開始")
|
accelerator.print("running training / 学習開始")
|
||||||
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||||
|
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
|
||||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
@@ -917,6 +918,7 @@ class NetworkTrainer:
|
|||||||
"ss_text_encoder_lr": text_encoder_lr,
|
"ss_text_encoder_lr": text_encoder_lr,
|
||||||
"ss_unet_lr": args.unet_lr,
|
"ss_unet_lr": args.unet_lr,
|
||||||
"ss_num_train_images": train_dataset_group.num_train_images,
|
"ss_num_train_images": train_dataset_group.num_train_images,
|
||||||
|
"ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0,
|
||||||
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
||||||
"ss_num_batches_per_epoch": len(train_dataloader),
|
"ss_num_batches_per_epoch": len(train_dataloader),
|
||||||
"ss_num_epochs": num_train_epochs,
|
"ss_num_epochs": num_train_epochs,
|
||||||
@@ -964,6 +966,11 @@ class NetworkTrainer:
|
|||||||
"ss_huber_c": args.huber_c,
|
"ss_huber_c": args.huber_c,
|
||||||
"ss_fp8_base": bool(args.fp8_base),
|
"ss_fp8_base": bool(args.fp8_base),
|
||||||
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
||||||
|
"ss_validation_seed": args.validation_seed,
|
||||||
|
"ss_validation_split": args.validation_split,
|
||||||
|
"ss_max_validation_steps": args.max_validation_steps,
|
||||||
|
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||||
|
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.update_metadata(metadata, args) # architecture specific metadata
|
self.update_metadata(metadata, args) # architecture specific metadata
|
||||||
|
|||||||
Reference in New Issue
Block a user