mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix conditioning mask and alpha mask so that both work correctly
This commit is contained in:
@@ -78,7 +78,6 @@ class BaseSubsetParams:
|
||||
caption_tag_dropout_rate: float = 0.0
|
||||
token_warmup_min: int = 1
|
||||
token_warmup_step: float = 0
|
||||
alpha_mask: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -87,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams):
|
||||
class_tokens: Optional[str] = None
|
||||
caption_extension: str = ".caption"
|
||||
cache_info: bool = False
|
||||
alpha_mask: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FineTuningSubsetParams(BaseSubsetParams):
|
||||
metadata_file: Optional[str] = None
|
||||
alpha_mask: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -484,9 +484,11 @@ def apply_masked_loss(loss, batch):
|
||||
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
mask_image = mask_image / 2 + 0.5
|
||||
# print(f"conditioning_image: {mask_image.shape}")
|
||||
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
|
||||
# alpha mask is 0 to 1
|
||||
mask_image = batch["alpha_masks"].to(dtype=loss.dtype)
|
||||
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
||||
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
|
||||
else:
|
||||
return loss
|
||||
|
||||
|
||||
@@ -561,6 +561,7 @@ class ControlNetSubset(BaseSubset):
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
False, # alpha_mask
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
@@ -1947,6 +1948,7 @@ class ControlNetDataset(BaseDataset):
|
||||
None,
|
||||
subset.caption_extension,
|
||||
subset.cache_info,
|
||||
False,
|
||||
subset.num_repeats,
|
||||
subset.shuffle_caption,
|
||||
subset.caption_separator,
|
||||
@@ -2196,6 +2198,9 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
|
||||
return False
|
||||
if npz["alpha_mask"].shape[0:2] != reso: # HxW
|
||||
return False
|
||||
else:
|
||||
if "alpha_mask" in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -2296,6 +2301,13 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
if os.name == "nt":
|
||||
cv2.imshow("cond_img", cond_img)
|
||||
|
||||
if "alpha_masks" in example and example["alpha_masks"] is not None:
|
||||
alpha_mask = example["alpha_masks"][j]
|
||||
logger.info(f"alpha mask size: {alpha_mask.size()}")
|
||||
alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
|
||||
if os.name == "nt":
|
||||
cv2.imshow("alpha_mask", alpha_mask)
|
||||
|
||||
if os.name == "nt": # only windows
|
||||
cv2.imshow("img", im)
|
||||
k = cv2.waitKey()
|
||||
|
||||
Reference in New Issue
Block a user