mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix to work cond mask and alpha mask
This commit is contained in:
@@ -561,6 +561,7 @@ class ControlNetSubset(BaseSubset):
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
False, # alpha_mask
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
@@ -1947,6 +1948,7 @@ class ControlNetDataset(BaseDataset):
|
||||
None,
|
||||
subset.caption_extension,
|
||||
subset.cache_info,
|
||||
False,
|
||||
subset.num_repeats,
|
||||
subset.shuffle_caption,
|
||||
subset.caption_separator,
|
||||
@@ -2196,6 +2198,9 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
|
||||
return False
|
||||
if npz["alpha_mask"].shape[0:2] != reso: # HxW
|
||||
return False
|
||||
else:
|
||||
if "alpha_mask" in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -2296,6 +2301,13 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
if os.name == "nt":
|
||||
cv2.imshow("cond_img", cond_img)
|
||||
|
||||
if "alpha_masks" in example and example["alpha_masks"] is not None:
|
||||
alpha_mask = example["alpha_masks"][j]
|
||||
logger.info(f"alpha mask size: {alpha_mask.size()}")
|
||||
alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
|
||||
if os.name == "nt":
|
||||
cv2.imshow("alpha_mask", alpha_mask)
|
||||
|
||||
if os.name == "nt": # only windows
|
||||
cv2.imshow("img", im)
|
||||
k = cv2.waitKey()
|
||||
|
||||
Reference in New Issue
Block a user