fix to work cond mask and alpha mask

This commit is contained in:
Kohya S
2024-05-26 22:01:37 +09:00
parent da6fea3d97
commit e8cfd4ba1d
3 changed files with 17 additions and 2 deletions

View File

@@ -561,6 +561,7 @@ class ControlNetSubset(BaseSubset):
super().__init__(
image_dir,
False, # alpha_mask
num_repeats,
shuffle_caption,
caption_separator,
@@ -1947,6 +1948,7 @@ class ControlNetDataset(BaseDataset):
None,
subset.caption_extension,
subset.cache_info,
False,
subset.num_repeats,
subset.shuffle_caption,
subset.caption_separator,
@@ -2196,6 +2198,9 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
return False
if npz["alpha_mask"].shape[0:2] != reso: # HxW
return False
else:
if "alpha_mask" in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -2296,6 +2301,13 @@ def debug_dataset(train_dataset, show_input_ids=False):
if os.name == "nt":
cv2.imshow("cond_img", cond_img)
if "alpha_masks" in example and example["alpha_masks"] is not None:
alpha_mask = example["alpha_masks"][j]
logger.info(f"alpha mask size: {alpha_mask.size()}")
alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
if os.name == "nt":
cv2.imshow("alpha_mask", alpha_mask)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()