diff --git a/library/config_util.py b/library/config_util.py
index ec6ef4b2..0da0b143 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -491,8 +491,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
     else:
       subset_klass = FineTuningSubset
       dataset_klass = FineTuningDataset
-
-    subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False]
+    if subset_klass == DreamBoothSubset:
+      subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False]
+    else:
+      subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
     dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
     val_datasets.append(dataset)
 
diff --git a/library/train_util.py b/library/train_util.py
index 89297962..ae7968d7 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1816,6 +1816,7 @@ class ControlNetDataset(BaseDataset):
     def __init__(
         self,
         subsets: Sequence[ControlNetSubset],
+        is_train: bool,
         batch_size: int,
         tokenizer,
         max_token_length,
@@ -1826,6 +1827,8 @@ class ControlNetDataset(BaseDataset):
         max_bucket_reso: int,
         bucket_reso_steps: int,
         bucket_no_upscale: bool,
+        validation_split: float,
+        validation_seed: Optional[int],
         debug_dataset: float,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
@@ -1860,6 +1863,7 @@ class ControlNetDataset(BaseDataset):
 
         self.dreambooth_dataset_delegate = DreamBoothDataset(
             db_subsets,
+            is_train,
             batch_size,
             tokenizer,
             max_token_length,
@@ -1871,6 +1875,8 @@ class ControlNetDataset(BaseDataset):
             max_bucket_reso,
             bucket_reso_steps,
             bucket_no_upscale,
             1.0,
+            validation_split,
+            validation_seed,
             debug_dataset,
         )
@@ -1878,7 +1884,10 @@ class ControlNetDataset(BaseDataset):
         self.image_data = self.dreambooth_dataset_delegate.image_data
         self.batch_size = batch_size
         self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
-        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
+        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
+        self.is_train = is_train
+        self.validation_split = validation_split
+        self.validation_seed = validation_seed
 
         # assert all conditioning data exists
         missing_imgs = []
@@ -1911,8 +1920,8 @@ class ControlNetDataset(BaseDataset):
                 [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
             )
 
-        assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
-        assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
+        #assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
+        #assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
 
         self.conditioning_image_transforms = IMAGE_TRANSFORMS