diff --git a/library/config_util.py b/library/config_util.py index c6667690..8f01e1f6 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -510,8 +510,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu subset_blueprint.params.caption_dropout_every_n_epochs = 0 subset_blueprint.params.caption_tag_dropout_rate = 0.0 subset_blueprint.params.token_warmup_step = 0 - if subset_klass != DreamBoothSubset or not subset_blueprint.params.is_reg: - subsets.append(subset_klass(**asdict(subset_blueprint.params))) + + if subset_klass != DreamBoothSubset or (subset_klass == DreamBoothSubset and not subset_blueprint.params.is_reg): + subset = subset_klass(**asdict(subset_blueprint.params)) + subsets.append(subset) dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params)) val_datasets.append(dataset)