mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
revert kwargs to explicit declaration
This commit is contained in:
@@ -409,6 +409,7 @@ class BaseSubset:
|
|||||||
|
|
||||||
self.alpha_mask = alpha_mask
|
self.alpha_mask = alpha_mask
|
||||||
|
|
||||||
|
|
||||||
class DreamBoothSubset(BaseSubset):
|
class DreamBoothSubset(BaseSubset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -417,13 +418,47 @@ class DreamBoothSubset(BaseSubset):
|
|||||||
class_tokens: Optional[str],
|
class_tokens: Optional[str],
|
||||||
caption_extension: str,
|
caption_extension: str,
|
||||||
cache_info: bool,
|
cache_info: bool,
|
||||||
**kwargs,
|
num_repeats,
|
||||||
|
shuffle_caption,
|
||||||
|
caption_separator: str,
|
||||||
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
|
secondary_separator,
|
||||||
|
enable_wildcard,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
face_crop_aug_range,
|
||||||
|
random_crop,
|
||||||
|
caption_dropout_rate,
|
||||||
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_tag_dropout_rate,
|
||||||
|
caption_prefix,
|
||||||
|
caption_suffix,
|
||||||
|
token_warmup_min,
|
||||||
|
token_warmup_step,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
image_dir,
|
image_dir,
|
||||||
**kwargs,
|
num_repeats,
|
||||||
|
shuffle_caption,
|
||||||
|
caption_separator,
|
||||||
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
|
secondary_separator,
|
||||||
|
enable_wildcard,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
face_crop_aug_range,
|
||||||
|
random_crop,
|
||||||
|
caption_dropout_rate,
|
||||||
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_tag_dropout_rate,
|
||||||
|
caption_prefix,
|
||||||
|
caption_suffix,
|
||||||
|
token_warmup_min,
|
||||||
|
token_warmup_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.is_reg = is_reg
|
self.is_reg = is_reg
|
||||||
@@ -444,13 +479,47 @@ class FineTuningSubset(BaseSubset):
|
|||||||
self,
|
self,
|
||||||
image_dir,
|
image_dir,
|
||||||
metadata_file: str,
|
metadata_file: str,
|
||||||
**kwargs,
|
num_repeats,
|
||||||
|
shuffle_caption,
|
||||||
|
caption_separator,
|
||||||
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
|
secondary_separator,
|
||||||
|
enable_wildcard,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
face_crop_aug_range,
|
||||||
|
random_crop,
|
||||||
|
caption_dropout_rate,
|
||||||
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_tag_dropout_rate,
|
||||||
|
caption_prefix,
|
||||||
|
caption_suffix,
|
||||||
|
token_warmup_min,
|
||||||
|
token_warmup_step,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
image_dir,
|
image_dir,
|
||||||
**kwargs,
|
num_repeats,
|
||||||
|
shuffle_caption,
|
||||||
|
caption_separator,
|
||||||
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
|
secondary_separator,
|
||||||
|
enable_wildcard,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
face_crop_aug_range,
|
||||||
|
random_crop,
|
||||||
|
caption_dropout_rate,
|
||||||
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_tag_dropout_rate,
|
||||||
|
caption_prefix,
|
||||||
|
caption_suffix,
|
||||||
|
token_warmup_min,
|
||||||
|
token_warmup_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.metadata_file = metadata_file
|
self.metadata_file = metadata_file
|
||||||
@@ -468,13 +537,47 @@ class ControlNetSubset(BaseSubset):
|
|||||||
conditioning_data_dir: str,
|
conditioning_data_dir: str,
|
||||||
caption_extension: str,
|
caption_extension: str,
|
||||||
cache_info: bool,
|
cache_info: bool,
|
||||||
**kwargs,
|
num_repeats,
|
||||||
|
shuffle_caption,
|
||||||
|
caption_separator,
|
||||||
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
|
secondary_separator,
|
||||||
|
enable_wildcard,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
face_crop_aug_range,
|
||||||
|
random_crop,
|
||||||
|
caption_dropout_rate,
|
||||||
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_tag_dropout_rate,
|
||||||
|
caption_prefix,
|
||||||
|
caption_suffix,
|
||||||
|
token_warmup_min,
|
||||||
|
token_warmup_step,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
image_dir,
|
image_dir,
|
||||||
**kwargs,
|
num_repeats,
|
||||||
|
shuffle_caption,
|
||||||
|
caption_separator,
|
||||||
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
|
secondary_separator,
|
||||||
|
enable_wildcard,
|
||||||
|
color_aug,
|
||||||
|
flip_aug,
|
||||||
|
face_crop_aug_range,
|
||||||
|
random_crop,
|
||||||
|
caption_dropout_rate,
|
||||||
|
caption_dropout_every_n_epochs,
|
||||||
|
caption_tag_dropout_rate,
|
||||||
|
caption_prefix,
|
||||||
|
caption_suffix,
|
||||||
|
token_warmup_min,
|
||||||
|
token_warmup_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conditioning_data_dir = conditioning_data_dir
|
self.conditioning_data_dir = conditioning_data_dir
|
||||||
@@ -1100,10 +1203,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
else:
|
else:
|
||||||
latents = image_info.latents_flipped
|
latents = image_info.latents_flipped
|
||||||
alpha_mask = image_info.alpha_mask_flipped
|
alpha_mask = image_info.alpha_mask_flipped
|
||||||
|
|
||||||
image = None
|
image = None
|
||||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz)
|
latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(
|
||||||
|
image_info.latents_npz
|
||||||
|
)
|
||||||
if flipped:
|
if flipped:
|
||||||
latents = flipped_latents
|
latents = flipped_latents
|
||||||
alpha_mask = flipped_alpha_mask
|
alpha_mask = flipped_alpha_mask
|
||||||
@@ -1116,7 +1221,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
image = None
|
image = None
|
||||||
else:
|
else:
|
||||||
# 画像を読み込み、必要ならcropする
|
# 画像を読み込み、必要ならcropする
|
||||||
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask)
|
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
|
||||||
|
subset, image_info.absolute_path, subset.alpha_mask
|
||||||
|
)
|
||||||
im_h, im_w = img.shape[0:2]
|
im_h, im_w = img.shape[0:2]
|
||||||
|
|
||||||
if self.enable_bucket:
|
if self.enable_bucket:
|
||||||
@@ -1157,7 +1264,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
if img.shape[2] == 4:
|
if img.shape[2] == 4:
|
||||||
alpha_mask = img[:, :, 3] # [W,H]
|
alpha_mask = img[:, :, 3] # [W,H]
|
||||||
else:
|
else:
|
||||||
alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
|
alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
|
||||||
alpha_mask = transforms.ToTensor()(alpha_mask)
|
alpha_mask = transforms.ToTensor()(alpha_mask)
|
||||||
else:
|
else:
|
||||||
alpha_mask = None
|
alpha_mask = None
|
||||||
@@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
|
|||||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||||
def load_latents_from_disk(
|
def load_latents_from_disk(
|
||||||
npz_path,
|
npz_path,
|
||||||
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
) -> Tuple[
|
||||||
|
Optional[torch.Tensor],
|
||||||
|
Optional[List[int]],
|
||||||
|
Optional[List[int]],
|
||||||
|
Optional[torch.Tensor],
|
||||||
|
Optional[torch.Tensor],
|
||||||
|
Optional[torch.Tensor],
|
||||||
|
]:
|
||||||
npz = np.load(npz_path)
|
npz = np.load(npz_path)
|
||||||
if "latents" not in npz:
|
if "latents" not in npz:
|
||||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||||
@@ -2084,7 +2198,9 @@ def load_latents_from_disk(
|
|||||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
|
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
|
||||||
|
|
||||||
|
|
||||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None):
|
def save_latents_to_disk(
|
||||||
|
npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None
|
||||||
|
):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if flipped_latents_tensor is not None:
|
if flipped_latents_tensor is not None:
|
||||||
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||||
@@ -2344,10 +2460,10 @@ def cache_batch_latents(
|
|||||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||||
if info.use_alpha_mask:
|
if info.use_alpha_mask:
|
||||||
if image.shape[2] == 4:
|
if image.shape[2] == 4:
|
||||||
alpha_mask = image[:, :, 3] # [W,H]
|
alpha_mask = image[:, :, 3] # [W,H]
|
||||||
image = image[:, :, :3]
|
image = image[:, :, :3]
|
||||||
else:
|
else:
|
||||||
alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
|
alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
|
||||||
alpha_masks.append(transforms.ToTensor()(alpha_mask))
|
alpha_masks.append(transforms.ToTensor()(alpha_mask))
|
||||||
image = IMAGE_TRANSFORMS(image)
|
image = IMAGE_TRANSFORMS(image)
|
||||||
images.append(image)
|
images.append(image)
|
||||||
@@ -2377,13 +2493,23 @@ def cache_batch_latents(
|
|||||||
flipped_latents = [None] * len(latents)
|
flipped_latents = [None] * len(latents)
|
||||||
flipped_alpha_masks = [None] * len(image_infos)
|
flipped_alpha_masks = [None] * len(image_infos)
|
||||||
|
|
||||||
for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks):
|
for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(
|
||||||
|
image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks
|
||||||
|
):
|
||||||
# check NaN
|
# check NaN
|
||||||
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
|
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
|
||||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||||
|
|
||||||
if cache_to_disk:
|
if cache_to_disk:
|
||||||
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask)
|
save_latents_to_disk(
|
||||||
|
info.latents_npz,
|
||||||
|
latent,
|
||||||
|
info.latents_original_size,
|
||||||
|
info.latents_crop_ltrb,
|
||||||
|
flipped_latent,
|
||||||
|
alpha_mask,
|
||||||
|
flipped_alpha_mask,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
info.latents = latent
|
info.latents = latent
|
||||||
if flip_aug:
|
if flip_aug:
|
||||||
|
|||||||
Reference in New Issue
Block a user