diff --git a/library/config_util.py b/library/config_util.py index 834d6bfa..a2e07dc6 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -471,36 +471,49 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: + extra_dataset_params = {} + if dataset_blueprint.is_controlnet: subset_klass = ControlNetSubset dataset_klass = ControlNetDataset elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": True} else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) datasets.append(dataset) val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: - if dataset_blueprint.params.validation_split <= 0.0: + if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0: + logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...") continue + + # if the dataset isn't setting a validation split, there is no current validation dataset + if dataset_blueprint.params.validation_split == 0.0: + continue + + extra_dataset_params = {} if dataset_blueprint.is_controlnet: subset_klass = ControlNetSubset dataset_klass = ControlNetDataset elif dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset + # DreamBooth datasets support splitting training and validation datasets + extra_dataset_params = {"is_training_dataset": False} else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] - dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params)) + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params) val_datasets.append(dataset) def print_info(_datasets, dataset_type: str):