mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
simplify and update alpha mask to work with various cases
This commit is contained in:
@@ -159,9 +159,7 @@ class ImageInfo:
|
||||
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
self.alpha_mask: Optional[torch.Tensor] = None
|
||||
self.alpha_mask_flipped: Optional[torch.Tensor] = None
|
||||
self.use_alpha_mask: bool = False
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
|
||||
class BucketManager:
|
||||
@@ -364,6 +362,7 @@ class BaseSubset:
|
||||
def __init__(
|
||||
self,
|
||||
image_dir: Optional[str],
|
||||
alpha_mask: Optional[bool],
|
||||
num_repeats: int,
|
||||
shuffle_caption: bool,
|
||||
caption_separator: str,
|
||||
@@ -382,9 +381,9 @@ class BaseSubset:
|
||||
caption_suffix: Optional[str],
|
||||
token_warmup_min: int,
|
||||
token_warmup_step: Union[float, int],
|
||||
alpha_mask: bool,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||
self.num_repeats = num_repeats
|
||||
self.shuffle_caption = shuffle_caption
|
||||
self.caption_separator = caption_separator
|
||||
@@ -407,8 +406,6 @@ class BaseSubset:
|
||||
|
||||
self.img_count = 0
|
||||
|
||||
self.alpha_mask = alpha_mask
|
||||
|
||||
|
||||
class DreamBoothSubset(BaseSubset):
|
||||
def __init__(
|
||||
@@ -418,6 +415,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
class_tokens: Optional[str],
|
||||
caption_extension: str,
|
||||
cache_info: bool,
|
||||
alpha_mask: bool,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator: str,
|
||||
@@ -441,6 +439,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
alpha_mask,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
@@ -479,6 +478,7 @@ class FineTuningSubset(BaseSubset):
|
||||
self,
|
||||
image_dir,
|
||||
metadata_file: str,
|
||||
alpha_mask: bool,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
@@ -502,6 +502,7 @@ class FineTuningSubset(BaseSubset):
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
alpha_mask,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
@@ -921,7 +922,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
||||
self.buckets_indices: List(BucketBatchIndex) = []
|
||||
self.buckets_indices: List[BucketBatchIndex] = []
|
||||
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
@@ -991,8 +992,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for info in tqdm(image_infos):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
info.use_alpha_mask = subset.alpha_mask
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
@@ -1002,7 +1001,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
|
||||
cache_available = is_disk_cached_latents_is_expected(
|
||||
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
|
||||
)
|
||||
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
@@ -1028,7 +1029,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
|
||||
logger.info("caching latents...")
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
|
||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
|
||||
# weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
|
||||
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
|
||||
@@ -1202,18 +1203,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
alpha_mask = image_info.alpha_mask
|
||||
else:
|
||||
latents = image_info.latents_flipped
|
||||
alpha_mask = image_info.alpha_mask_flipped
|
||||
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
|
||||
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(
|
||||
image_info.latents_npz
|
||||
)
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz)
|
||||
if flipped:
|
||||
latents = flipped_latents
|
||||
alpha_mask = flipped_alpha_mask
|
||||
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
|
||||
del flipped_latents
|
||||
del flipped_alpha_mask
|
||||
latents = torch.FloatTensor(latents)
|
||||
if alpha_mask is not None:
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
@@ -1255,23 +1253,28 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# augmentation
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug)
|
||||
if aug is not None:
|
||||
img = aug(image=img)["image"]
|
||||
# augment RGB channels only
|
||||
img_rgb = img[:, :, :3]
|
||||
img_rgb = aug(image=img_rgb)["image"]
|
||||
img[:, :, :3] = img_rgb
|
||||
|
||||
if flipped:
|
||||
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem
|
||||
|
||||
if subset.alpha_mask:
|
||||
if img.shape[2] == 4:
|
||||
alpha_mask = img[:, :, 3] # [W,H]
|
||||
alpha_mask = img[:, :, 3] # [H,W]
|
||||
alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1
|
||||
else:
|
||||
alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
|
||||
alpha_mask = transforms.ToTensor()(alpha_mask)
|
||||
alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
|
||||
else:
|
||||
alpha_mask = None
|
||||
|
||||
img = img[:, :, :3] # remove alpha channel
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
del img
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
@@ -1361,6 +1364,23 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
|
||||
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
|
||||
|
||||
# if one of alpha_masks is not None, we need to replace None with ones
|
||||
none_or_not = [x is None for x in alpha_mask_list]
|
||||
if all(none_or_not):
|
||||
example["alpha_masks"] = None
|
||||
elif any(none_or_not):
|
||||
for i in range(len(alpha_mask_list)):
|
||||
if alpha_mask_list[i] is None:
|
||||
if images[i] is not None:
|
||||
alpha_mask_list[i] = torch.ones((images[i].shape[1], images[i].shape[2]), dtype=torch.float32)
|
||||
else:
|
||||
alpha_mask_list[i] = torch.ones(
|
||||
(latents_list[i].shape[1] * 8, latents_list[i].shape[2] * 8), dtype=torch.float32
|
||||
)
|
||||
example["alpha_masks"] = torch.stack(alpha_mask_list)
|
||||
else:
|
||||
example["alpha_masks"] = torch.stack(alpha_mask_list)
|
||||
|
||||
if images[0] is not None:
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
@@ -1378,8 +1398,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||
|
||||
example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
return example
|
||||
@@ -1393,6 +1411,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
resized_sizes = []
|
||||
bucket_reso = None
|
||||
flip_aug = None
|
||||
alpha_mask = None
|
||||
random_crop = None
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
@@ -1401,10 +1420,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
if flip_aug is None:
|
||||
flip_aug = subset.flip_aug
|
||||
alpha_mask = subset.alpha_mask
|
||||
random_crop = subset.random_crop
|
||||
bucket_reso = image_info.bucket_reso
|
||||
else:
|
||||
# TODO そもそも混在してても動くようにしたほうがいい
|
||||
assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
|
||||
assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch"
|
||||
assert random_crop == subset.random_crop, "random_crop must be same in a batch"
|
||||
assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
|
||||
|
||||
@@ -1441,6 +1463,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
example["absolute_paths"] = absolute_paths
|
||||
example["resized_sizes"] = resized_sizes
|
||||
example["flip_aug"] = flip_aug
|
||||
example["alpha_mask"] = alpha_mask
|
||||
example["random_crop"] = random_crop
|
||||
example["bucket_reso"] = bucket_reso
|
||||
return example
|
||||
@@ -2149,7 +2172,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
dataset.disable_token_padding()
|
||||
|
||||
|
||||
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
|
||||
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
|
||||
if not os.path.exists(npz_path):
|
||||
@@ -2167,6 +2190,12 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
|
||||
return False
|
||||
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
|
||||
return False
|
||||
|
||||
if alpha_mask:
|
||||
if "alpha_mask" not in npz:
|
||||
return False
|
||||
if npz["alpha_mask"].shape[0:2] != reso: # HxW
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -2177,14 +2206,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
|
||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||
def load_latents_from_disk(
|
||||
npz_path,
|
||||
) -> Tuple[
|
||||
Optional[torch.Tensor],
|
||||
Optional[List[int]],
|
||||
Optional[List[int]],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]:
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz:
|
||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
@@ -2194,20 +2216,15 @@ def load_latents_from_disk(
|
||||
crop_ltrb = npz["crop_ltrb"].tolist()
|
||||
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
||||
flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
|
||||
def save_latents_to_disk(
|
||||
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None
|
||||
):
|
||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
|
||||
kwargs = {}
|
||||
if flipped_latents_tensor is not None:
|
||||
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||
if alpha_mask is not None:
|
||||
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||
if flipped_alpha_mask is not None:
|
||||
kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy()
|
||||
kwargs["alpha_mask"] = alpha_mask # ndarray
|
||||
np.savez(
|
||||
npz_path,
|
||||
latents=latents_tensor.float().cpu().numpy(),
|
||||
@@ -2398,10 +2415,11 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
|
||||
|
||||
def load_image(image_path, alpha=False):
|
||||
image = Image.open(image_path)
|
||||
if not image.mode == "RGB":
|
||||
if alpha:
|
||||
if alpha:
|
||||
if not image.mode == "RGBA":
|
||||
image = image.convert("RGBA")
|
||||
else:
|
||||
else:
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
@@ -2441,7 +2459,7 @@ def trim_and_resize_if_required(
|
||||
|
||||
|
||||
def cache_batch_latents(
|
||||
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
|
||||
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
|
||||
) -> None:
|
||||
r"""
|
||||
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
|
||||
@@ -2453,49 +2471,43 @@ def cache_batch_latents(
|
||||
latents_original_size and latents_crop_ltrb are also set
|
||||
"""
|
||||
images = []
|
||||
alpha_masks = []
|
||||
alpha_masks: List[np.ndarray] = []
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
if info.use_alpha_mask:
|
||||
if image.shape[2] == 4:
|
||||
alpha_mask = image[:, :, 3] # [W,H]
|
||||
image = image[:, :, :3]
|
||||
else:
|
||||
alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
|
||||
alpha_masks.append(transforms.ToTensor()(alpha_mask))
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
images.append(image)
|
||||
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
|
||||
if use_alpha_mask:
|
||||
if image.shape[2] == 4:
|
||||
alpha_mask = image[:, :, 3] # [H,W]
|
||||
alpha_mask = alpha_mask.astype(np.float32) / 255.0
|
||||
else:
|
||||
alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32)
|
||||
else:
|
||||
alpha_mask = None
|
||||
alpha_masks.append(alpha_mask)
|
||||
|
||||
image = image[:, :, :3] # remove alpha channel if exists
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
images.append(image)
|
||||
|
||||
img_tensors = torch.stack(images, dim=0)
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
|
||||
if info.use_alpha_mask:
|
||||
alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu")
|
||||
else:
|
||||
alpha_masks = [None] * len(image_infos)
|
||||
flipped_alpha_masks = [None] * len(image_infos)
|
||||
|
||||
if flip_aug:
|
||||
img_tensors = torch.flip(img_tensors, dims=[3])
|
||||
with torch.no_grad():
|
||||
flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
if info.use_alpha_mask:
|
||||
flipped_alpha_masks = torch.flip(alpha_masks, dims=[3])
|
||||
else:
|
||||
flipped_latents = [None] * len(latents)
|
||||
flipped_alpha_masks = [None] * len(image_infos)
|
||||
|
||||
for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(
|
||||
image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks
|
||||
):
|
||||
for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks):
|
||||
# check NaN
|
||||
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
|
||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||
@@ -2508,15 +2520,12 @@ def cache_batch_latents(
|
||||
info.latents_crop_ltrb,
|
||||
flipped_latent,
|
||||
alpha_mask,
|
||||
flipped_alpha_mask,
|
||||
)
|
||||
else:
|
||||
info.latents = latent
|
||||
if flip_aug:
|
||||
info.latents_flipped = flipped_latent
|
||||
|
||||
info.alpha_mask = alpha_mask
|
||||
info.alpha_mask_flipped = flipped_alpha_mask
|
||||
|
||||
if not HIGH_VRAM:
|
||||
clean_memory_on_device(vae.device)
|
||||
|
||||
Reference in New Issue
Block a user