revert kwargs to explicit declaration

Kohya S
2024-05-19 19:23:59 +09:00
parent db6752901f
commit f2dd43e198
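The diff below swaps the `**kwargs` pass-through in the DreamBoothSubset, FineTuningSubset, and ControlNetSubset constructors for explicitly declared parameters. As a quick illustration of the trade-off, here is a minimal sketch with hypothetical classes (BaseLike, KwargsSubset, ExplicitSubset are illustrative names, not the actual train_util.py code):

class BaseLike:
    def __init__(self, image_dir, num_repeats, shuffle_caption):
        self.image_dir = image_dir
        self.num_repeats = num_repeats
        self.shuffle_caption = shuffle_caption


# Before: arguments tunnel through **kwargs, so the subclass signature says
# nothing about what it accepts, and a bad keyword is reported against
# BaseLike rather than the class the caller actually constructed.
class KwargsSubset(BaseLike):
    def __init__(self, image_dir, **kwargs):
        super().__init__(image_dir, **kwargs)


# After: every parameter is declared, so editors and type checkers see the
# real signature, and a misspelled keyword fails at the subclass call site.
class ExplicitSubset(BaseLike):
    def __init__(self, image_dir, num_repeats, shuffle_caption):
        super().__init__(image_dir, num_repeats, shuffle_caption)

The cost of the explicit style is the long repeated parameter lists visible in the three hunks below.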

library/train_util.py

@@ -409,6 +409,7 @@ class BaseSubset:
 
         self.alpha_mask = alpha_mask
 
+
 class DreamBoothSubset(BaseSubset):
     def __init__(
         self,
@@ -417,13 +418,47 @@ class DreamBoothSubset(BaseSubset):
         class_tokens: Optional[str],
         caption_extension: str,
         cache_info: bool,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator: str,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.is_reg = is_reg
@@ -444,13 +479,47 @@ class FineTuningSubset(BaseSubset):
         self,
         image_dir,
         metadata_file: str,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.metadata_file = metadata_file
@@ -468,13 +537,47 @@ class ControlNetSubset(BaseSubset):
         conditioning_data_dir: str,
         caption_extension: str,
         cache_info: bool,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.conditioning_data_dir = conditioning_data_dir
@@ -1103,7 +1206,9 @@ class BaseDataset(torch.utils.data.Dataset):
 
                 image = None
             elif image_info.latents_npz is not None:  # for FineTuningDataset, or when cache_latents_to_disk=True
-                latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz)
+                latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(
+                    image_info.latents_npz
+                )
                 if flipped:
                     latents = flipped_latents
                     alpha_mask = flipped_alpha_mask
@@ -1116,7 +1221,9 @@ class BaseDataset(torch.utils.data.Dataset):
                 image = None
             else:
                 # load the image, and crop it if necessary
-                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask)
+                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
+                    subset, image_info.absolute_path, subset.alpha_mask
+                )
                 im_h, im_w = img.shape[0:2]
 
                 if self.enable_bucket:
@@ -1157,7 +1264,7 @@ class BaseDataset(torch.utils.data.Dataset):
                     if img.shape[2] == 4:
                         alpha_mask = img[:, :, 3]  # [W,H]
                     else:
-                        alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
+                        alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8)  # [W,H]
                     alpha_mask = transforms.ToTensor()(alpha_mask)
                 else:
                     alpha_mask = None
@@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
 # returns latents_tensor, (original_size width, original_size height), (crop left, crop top)
 def load_latents_from_disk(
     npz_path,
-) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+) -> Tuple[
+    Optional[torch.Tensor],
+    Optional[List[int]],
+    Optional[List[int]],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]:
     npz = np.load(npz_path)
     if "latents" not in npz:
         raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2084,7 +2198,9 @@ def load_latents_from_disk(
     return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
 
 
-def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None):
+def save_latents_to_disk(
+    npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None
+):
     kwargs = {}
     if flipped_latents_tensor is not None:
         kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
@@ -2344,10 +2460,10 @@ def cache_batch_latents(
         image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
         if info.use_alpha_mask:
             if image.shape[2] == 4:
-                alpha_mask = image[:, :, 3] # [W,H]
+                alpha_mask = image[:, :, 3]  # [W,H]
                 image = image[:, :, :3]
             else:
-                alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
+                alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8)  # [W,H]
             alpha_masks.append(transforms.ToTensor()(alpha_mask))
         image = IMAGE_TRANSFORMS(image)
         images.append(image)
@@ -2377,13 +2493,23 @@ def cache_batch_latents(
         flipped_latents = [None] * len(latents)
         flipped_alpha_masks = [None] * len(image_infos)
 
-    for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks):
+    for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(
+        image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks
+    ):
         # check NaN
         if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
             raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
 
         if cache_to_disk:
-            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask)
+            save_latents_to_disk(
+                info.latents_npz,
+                latent,
+                info.latents_original_size,
+                info.latents_crop_ltrb,
+                flipped_latent,
+                alpha_mask,
+                flipped_alpha_mask,
+            )
         else:
             info.latents = latent
             if flip_aug:
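For reference, the reflowed helpers keep the call contract visible in the hunks above: save_latents_to_disk writes the latents plus the optional flipped/alpha-mask arrays into a single .npz, and load_latents_from_disk returns a six-element tuple with None for anything that was not saved. A minimal round-trip sketch, assuming it runs from an sd-scripts checkout; the file name, image size, and crop box are placeholder values, and np.asarray absorbs whether the loader hands back numpy arrays or torch tensors:

import numpy as np
import torch

from library import train_util  # module shown in this diff

# one image's latents (C=4, 64x64) and a horizontally flipped copy
latent = torch.randn(4, 64, 64)
flipped = torch.flip(latent, dims=[2])

train_util.save_latents_to_disk(
    "sample_latents.npz",       # placeholder path
    latent,
    (512, 512),                 # original_size (width, height)
    (0, 0, 512, 512),           # crop_ltrb (assumed left, top, right, bottom)
    flipped_latents_tensor=flipped,
)

# alpha_mask / flipped_alpha_mask were not saved, so they come back as None
latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = (
    train_util.load_latents_from_disk("sample_latents.npz")
)
assert alpha_mask is None
assert np.allclose(np.asarray(latents), latent.numpy(), atol=1e-6)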