revert kwargs to explicit declaration

This commit is contained in:
Kohya S
2024-05-19 19:23:59 +09:00
parent db6752901f
commit f2dd43e198

View File

@@ -409,6 +409,7 @@ class BaseSubset:
self.alpha_mask = alpha_mask self.alpha_mask = alpha_mask
class DreamBoothSubset(BaseSubset): class DreamBoothSubset(BaseSubset):
def __init__( def __init__(
self, self,
@@ -417,13 +418,47 @@ class DreamBoothSubset(BaseSubset):
class_tokens: Optional[str], class_tokens: Optional[str],
caption_extension: str, caption_extension: str,
cache_info: bool, cache_info: bool,
**kwargs, num_repeats,
shuffle_caption,
caption_separator: str,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) -> None: ) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
super().__init__( super().__init__(
image_dir, image_dir,
**kwargs, num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) )
self.is_reg = is_reg self.is_reg = is_reg
@@ -444,13 +479,47 @@ class FineTuningSubset(BaseSubset):
self, self,
image_dir, image_dir,
metadata_file: str, metadata_file: str,
**kwargs, num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) -> None: ) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
super().__init__( super().__init__(
image_dir, image_dir,
**kwargs, num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) )
self.metadata_file = metadata_file self.metadata_file = metadata_file
@@ -468,13 +537,47 @@ class ControlNetSubset(BaseSubset):
conditioning_data_dir: str, conditioning_data_dir: str,
caption_extension: str, caption_extension: str,
cache_info: bool, cache_info: bool,
**kwargs, num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) -> None: ) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
super().__init__( super().__init__(
image_dir, image_dir,
**kwargs, num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) )
self.conditioning_data_dir = conditioning_data_dir self.conditioning_data_dir = conditioning_data_dir
@@ -1103,7 +1206,9 @@ class BaseDataset(torch.utils.data.Dataset):
image = None image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz) latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(
image_info.latents_npz
)
if flipped: if flipped:
latents = flipped_latents latents = flipped_latents
alpha_mask = flipped_alpha_mask alpha_mask = flipped_alpha_mask
@@ -1116,7 +1221,9 @@ class BaseDataset(torch.utils.data.Dataset):
image = None image = None
else: else:
# 画像を読み込み、必要ならcropする # 画像を読み込み、必要ならcropする
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask) img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
subset, image_info.absolute_path, subset.alpha_mask
)
im_h, im_w = img.shape[0:2] im_h, im_w = img.shape[0:2]
if self.enable_bucket: if self.enable_bucket:
@@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk( def load_latents_from_disk(
npz_path, npz_path,
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> Tuple[
Optional[torch.Tensor],
Optional[List[int]],
Optional[List[int]],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
npz = np.load(npz_path) npz = np.load(npz_path)
if "latents" not in npz: if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}") raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2084,7 +2198,9 @@ def load_latents_from_disk(
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None): def save_latents_to_disk(
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None
):
kwargs = {} kwargs = {}
if flipped_latents_tensor is not None: if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
@@ -2377,13 +2493,23 @@ def cache_batch_latents(
flipped_latents = [None] * len(latents) flipped_latents = [None] * len(latents)
flipped_alpha_masks = [None] * len(image_infos) flipped_alpha_masks = [None] * len(image_infos)
for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks): for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(
image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks
):
# check NaN # check NaN
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk: if cache_to_disk:
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask) save_latents_to_disk(
info.latents_npz,
latent,
info.latents_original_size,
info.latents_crop_ltrb,
flipped_latent,
alpha_mask,
flipped_alpha_mask,
)
else: else:
info.latents = latent info.latents = latent
if flip_aug: if flip_aug: