Mirror of https://github.com/kohya-ss/sd-scripts.git

Add Validation loss for LoRA training

@@ -73,6 +73,8 @@ class BaseSubsetParams:
     token_warmup_min: int = 1
     token_warmup_step: float = 0
     custom_attributes: Optional[Dict[str, Any]] = None
+    validation_seed: int = 0
+    validation_split: float = 0.0


 @dataclass
@@ -102,6 +104,8 @@ class BaseDatasetParams:
     resolution: Optional[Tuple[int, int]] = None
     network_multiplier: float = 1.0
     debug_dataset: bool = False
+    validation_seed: Optional[int] = None
+    validation_split: float = 0.0


 @dataclass
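The defaults at both the subset and dataset level mean "no validation" until a split fraction is set. A minimal sketch of the split arithmetic implied by the round(len(paths) * validation_split) expression in split_train_val further below; val_count is an illustrative name, not part of the commit:

    # Illustrative only: how a validation_split fraction maps to a held-out count.
    def val_count(n_images: int, validation_split: float) -> int:
        return round(n_images * validation_split)

    assert val_count(100, 0.0) == 0   # default: validation disabled
    assert val_count(100, 0.1) == 10  # 10% of the images are held out
    assert val_count(7, 0.2) == 1     # small sets round to the nearest integer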
@@ -478,9 +482,27 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
             dataset_klass = FineTuningDataset

         subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
-        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
+        dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
         datasets.append(dataset)

+    val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
+    for dataset_blueprint in dataset_group_blueprint.datasets:
+        if dataset_blueprint.params.validation_split <= 0.0:
+            continue
+        if dataset_blueprint.is_controlnet:
+            subset_klass = ControlNetSubset
+            dataset_klass = ControlNetDataset
+        elif dataset_blueprint.is_dreambooth:
+            subset_klass = DreamBoothSubset
+            dataset_klass = DreamBoothDataset
+        else:
+            subset_klass = FineTuningSubset
+            dataset_klass = FineTuningDataset
+
+        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
+        dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
+        val_datasets.append(dataset)
+
     # print info
     info = ""
     for i, dataset in enumerate(datasets):
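Both dataset lists are built from the same blueprints: the second pass reuses each blueprint's params with only the is_train flag flipped, and skips any blueprint whose validation_split is not positive. Schematically (params stands in for dataset_blueprint.params):

    # Paired construction: identical params, opposite is_train flags.
    train_ds = dataset_klass(subsets=subsets, is_train=True, **asdict(params))
    val_ds = dataset_klass(subsets=subsets, is_train=False, **asdict(params))  # only if params.validation_split > 0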
@@ -566,6 +588,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu

     logger.info(f"{info}")

+    if len(val_datasets) > 0:
+        info = ""
+
+        for i, dataset in enumerate(val_datasets):
+            info += dedent(
+                f"""\
+                [Validation Dataset {i}]
+                  batch_size: {dataset.batch_size}
+                  resolution: {(dataset.width, dataset.height)}
+                  enable_bucket: {dataset.enable_bucket}
+                  network_multiplier: {dataset.network_multiplier}
+                """
+            )
+
+            if dataset.enable_bucket:
+                info += indent(
+                    dedent(
+                        f"""\
+                        min_bucket_reso: {dataset.min_bucket_reso}
+                        max_bucket_reso: {dataset.max_bucket_reso}
+                        bucket_reso_steps: {dataset.bucket_reso_steps}
+                        bucket_no_upscale: {dataset.bucket_no_upscale}
+                        \n"""
+                    ),
+                    "  ",
+                )
+            else:
+                info += "\n"
+
+            for j, subset in enumerate(dataset.subsets):
+                info += indent(
+                    dedent(
+                        f"""\
+                        [Subset {j} of Validation Dataset {i}]
+                          image_dir: "{subset.image_dir}"
+                          image_count: {subset.img_count}
+                          num_repeats: {subset.num_repeats}
+                        """
+                    ),
+                    "  ",
+                )
+
+        logger.info(f"{info}")
+
     # make buckets first because it determines the length of dataset
     # and set the same seed for all datasets
     seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
@@ -574,7 +640,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         dataset.make_buckets()
         dataset.set_seed(seed)

-    return DatasetGroup(datasets)
+    for i, dataset in enumerate(val_datasets):
+        logger.info(f"[Validation Dataset {i}]")
+        dataset.make_buckets()
+        dataset.set_seed(seed)
+
+    return (
+        DatasetGroup(datasets),
+        DatasetGroup(val_datasets) if val_datasets else None
+    )


 def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
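Because the function now returns a pair instead of a single DatasetGroup, every existing call site has to be updated to unpack two values. A hypothetical caller, with blueprint and make_dataloader as illustrative stand-ins:

    # The second element is None when no dataset requested a validation split.
    train_group, val_group = generate_dataset_group_by_blueprint(blueprint)

    train_loader = make_dataloader(train_group)
    val_loader = make_dataloader(val_group) if val_group is not None else None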
@@ -145,6 +145,17 @@ IMAGE_TRANSFORMS = transforms.Compose(
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"

+def split_train_val(paths: List[str], validation_split: float, validation_seed: int) -> List[str]:
+    if validation_seed is not None:
+        print(f"Using validation seed: {validation_seed}")
+        prevstate = random.getstate()
+        random.seed(validation_seed)
+        random.shuffle(paths)
+        random.setstate(prevstate)
+    else:
+        random.shuffle(paths)
+
+    return paths[len(paths) - round(len(paths) * validation_split):]

 class ImageInfo:
     def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
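This hunk and the rest below appear to come from library/train_util.py, and the earlier hunks from library/config_util.py (inferred from the class and function names; the per-file diff headers were lost in mirroring). As written, the helper shuffles the path list, restoring the global random state afterward when a seed is given so the rest of the pipeline's RNG stream is untouched, and returns the last round(len(paths) * validation_split) entries, i.e. the held-out validation slice. A small usage sketch:

    paths = [f"img_{i:02d}.png" for i in range(10)]

    # A fixed validation_seed makes the held-out set reproducible across runs:
    val_a = split_train_val(list(paths), 0.2, 42)
    val_b = split_train_val(list(paths), 0.2, 42)
    assert val_a == val_b
    assert len(val_a) == round(10 * 0.2)  # == 2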
@@ -397,6 +408,8 @@ class BaseSubset:
         token_warmup_min: int,
         token_warmup_step: Union[float, int],
         custom_attributes: Optional[Dict[str, Any]] = None,
+        validation_seed: Optional[int] = None,
+        validation_split: Optional[float] = 0.0,
     ) -> None:
         self.image_dir = image_dir
         self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -424,6 +437,9 @@ class BaseSubset:

         self.img_count = 0

+        self.validation_seed = validation_seed
+        self.validation_split = validation_split
+

 class DreamBoothSubset(BaseSubset):
     def __init__(
@@ -453,6 +469,8 @@ class DreamBoothSubset(BaseSubset):
         token_warmup_min,
         token_warmup_step,
         custom_attributes: Optional[Dict[str, Any]] = None,
+        validation_seed: Optional[int] = None,
+        validation_split: Optional[float] = 0.0,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

@@ -478,6 +496,8 @@ class DreamBoothSubset(BaseSubset):
             token_warmup_min,
             token_warmup_step,
             custom_attributes=custom_attributes,
+            validation_seed=validation_seed,
+            validation_split=validation_split,
         )

         self.is_reg = is_reg
@@ -518,6 +538,8 @@ class FineTuningSubset(BaseSubset):
         token_warmup_min,
         token_warmup_step,
         custom_attributes: Optional[Dict[str, Any]] = None,
+        validation_seed: Optional[int] = None,
+        validation_split: Optional[float] = 0.0,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"

@@ -543,6 +565,8 @@ class FineTuningSubset(BaseSubset):
             token_warmup_min,
             token_warmup_step,
             custom_attributes=custom_attributes,
+            validation_seed=validation_seed,
+            validation_split=validation_split,
         )

         self.metadata_file = metadata_file
@@ -579,6 +603,8 @@ class ControlNetSubset(BaseSubset):
         token_warmup_min,
         token_warmup_step,
         custom_attributes: Optional[Dict[str, Any]] = None,
+        validation_seed: Optional[int] = None,
+        validation_split: Optional[float] = 0.0,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

@@ -604,6 +630,8 @@ class ControlNetSubset(BaseSubset):
             token_warmup_min,
             token_warmup_step,
             custom_attributes=custom_attributes,
+            validation_seed=validation_seed,
+            validation_split=validation_split,
         )

         self.conditioning_data_dir = conditioning_data_dir
@@ -1799,6 +1827,9 @@ class DreamBoothDataset(BaseDataset):
         bucket_no_upscale: bool,
         prior_loss_weight: float,
         debug_dataset: bool,
+        is_train: bool,
+        validation_seed: int,
+        validation_split: float,
     ) -> None:
         super().__init__(resolution, network_multiplier, debug_dataset)

@@ -1808,6 +1839,9 @@ class DreamBoothDataset(BaseDataset):
         self.size = min(self.width, self.height)  # the shorter side
         self.prior_loss_weight = prior_loss_weight
         self.latents_cache = None
+        self.is_train = is_train
+        self.validation_seed = validation_seed
+        self.validation_split = validation_split

         self.enable_bucket = enable_bucket
         if self.enable_bucket:
@@ -1992,6 +2026,9 @@ class DreamBoothDataset(BaseDataset):
                 )
                 continue

+            if self.is_train == False:
+                img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed)
+
             if subset.is_reg:
                 num_reg_images += subset.num_repeats * len(img_paths)
             else:
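Worked through with concrete numbers: a validation dataset whose subset directory yields 10 images and whose validation_split is 0.2 keeps round(10 * 0.2) = 2 of the shuffled paths, while the corresponding training dataset (is_train=True) skips the branch and keeps all 10 in this version of the diff. Condensed control flow, with load_image_paths as a hypothetical stand-in for the directory scan above:

    img_paths = load_image_paths(subset.image_dir)  # hypothetical helper
    if not self.is_train:
        # validation keeps only the held-out tail of the seeded shuffle
        img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed)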
@@ -2009,7 +2046,11 @@ class DreamBoothDataset(BaseDataset):
             subset.img_count = len(img_paths)
             self.subsets.append(subset)

-        logger.info(f"{num_train_images} train images with repeating.")
+        if self.is_train:
+            logger.info(f"{num_train_images} train images with repeating.")
+        else:
+            logger.info(f"{num_train_images} validation images with repeating.")
+
         self.num_train_images = num_train_images

         logger.info(f"{num_reg_images} reg images.")
@@ -2050,6 +2091,9 @@ class FineTuningDataset(BaseDataset):
         bucket_reso_steps: int,
         bucket_no_upscale: bool,
         debug_dataset: bool,
+        is_train: bool,
+        validation_seed: int,
+        validation_split: float,
     ) -> None:
         super().__init__(resolution, network_multiplier, debug_dataset)

@@ -2276,6 +2320,9 @@ class ControlNetDataset(BaseDataset):
         bucket_reso_steps: int,
         bucket_no_upscale: bool,
         debug_dataset: float,
+        is_train: bool,
+        validation_seed: int,
+        validation_split: float,
     ) -> None:
         super().__init__(resolution, network_multiplier, debug_dataset)

@@ -2324,6 +2371,9 @@ class ControlNetDataset(BaseDataset):
             bucket_no_upscale,
             1.0,
             debug_dataset,
+            is_train,
+            validation_seed,
+            validation_split,
         )

         # store values referenced from config_util etc. (a bit awkward; should be cleaned up)
@@ -4887,7 +4937,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
             import schedulefree as sf
         except ImportError:
             raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")

         if optimizer_type == "RAdamScheduleFree".lower():
             optimizer_class = sf.RAdamScheduleFree
             logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}")
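Since this commit is about computing a validation loss, one caveat worth noting for the schedule-free path: per the schedulefree project's documentation, these optimizers maintain different parameter values for training and for evaluation, so a validation pass should be bracketed with optimizer.eval() and optimizer.train(). A minimal sketch, assuming optimizer is an sf.RAdamScheduleFree instance and compute_validation_loss is an illustrative helper, not part of this commit:

    import torch

    optimizer.eval()   # switch parameters to their evaluation (averaged) values
    with torch.no_grad():
        val_loss = compute_validation_loss(model, val_loader)  # illustrative helper
    optimizer.train()  # switch back before resuming training updates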