fix control net

This commit is contained in:
gesen2egee
2024-03-16 11:51:11 +08:00
parent d05965dbad
commit b5e8045df4
2 changed files with 16 additions and 5 deletions

View File

@@ -1816,6 +1816,7 @@ class ControlNetDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[ControlNetSubset],
is_train: bool,
batch_size: int,
tokenizer,
max_token_length,
@@ -1826,6 +1827,8 @@ class ControlNetDataset(BaseDataset):
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
validation_split: float,
validation_seed: Optional[int],
debug_dataset: float,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
@@ -1860,6 +1863,7 @@ class ControlNetDataset(BaseDataset):
self.dreambooth_dataset_delegate = DreamBoothDataset(
db_subsets,
is_train,
batch_size,
tokenizer,
max_token_length,
@@ -1871,6 +1875,8 @@ class ControlNetDataset(BaseDataset):
bucket_reso_steps,
bucket_no_upscale,
1.0,
validation_split,
validation_seed,
debug_dataset,
)
@@ -1878,7 +1884,10 @@ class ControlNetDataset(BaseDataset):
self.image_data = self.dreambooth_dataset_delegate.image_data
self.batch_size = batch_size
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.is_train = is_train
self.validation_split = validation_split
self.validation_seed = validation_seed
# assert all conditioning data exists
missing_imgs = []
@@ -1911,8 +1920,8 @@ class ControlNetDataset(BaseDataset):
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
)
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
#assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
#assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
self.conditioning_image_transforms = IMAGE_TRANSFORMS