fix control net

This commit is contained in:
gesen2egee
2024-03-16 11:51:11 +08:00
parent d05965dbad
commit b5e8045df4
2 changed files with 16 additions and 5 deletions

View File

@@ -491,8 +491,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False]
if subset_klass == DreamBoothSubset:
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False]
else:
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)

View File

@@ -1816,6 +1816,7 @@ class ControlNetDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[ControlNetSubset],
is_train: bool,
batch_size: int,
tokenizer,
max_token_length,
@@ -1826,6 +1827,8 @@ class ControlNetDataset(BaseDataset):
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
validation_split: float,
validation_seed: Optional[int],
debug_dataset: float,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
@@ -1860,6 +1863,7 @@ class ControlNetDataset(BaseDataset):
self.dreambooth_dataset_delegate = DreamBoothDataset(
db_subsets,
is_train,
batch_size,
tokenizer,
max_token_length,
@@ -1871,6 +1875,8 @@ class ControlNetDataset(BaseDataset):
bucket_reso_steps,
bucket_no_upscale,
1.0,
validation_split,
validation_seed,
debug_dataset,
)
@@ -1878,7 +1884,10 @@ class ControlNetDataset(BaseDataset):
self.image_data = self.dreambooth_dataset_delegate.image_data
self.batch_size = batch_size
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.is_train = is_train
self.validation_split = validation_split
self.validation_seed = validation_seed
# assert all conditioning data exists
missing_imgs = []
@@ -1911,8 +1920,8 @@ class ControlNetDataset(BaseDataset):
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
)
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
#assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
#assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
self.conditioning_image_transforms = IMAGE_TRANSFORMS