Update PO cached latents, move out functions, update calls

This commit is contained in:
rockerBOO
2025-04-27 17:38:50 -04:00
parent 74529743d4
commit d22c827544
11 changed files with 480 additions and 129 deletions

View File

@@ -1700,10 +1700,14 @@ class BaseDataset(torch.utils.data.Dataset):
latents = image_info.latents_flipped
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
image = None
images.append(image)
latents_list.append(latents)
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
@@ -1715,12 +1719,16 @@ class BaseDataset(torch.utils.data.Dataset):
latents = torch.FloatTensor(latents)
if alpha_mask is not None:
alpha_mask = torch.FloatTensor(alpha_mask)
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
image = None
images.append(image)
latents_list.append(latents)
alpha_mask_list.append(alpha_mask)
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
else:
if isinstance(image_info, ImageSetInfo):
for absolute_path in image_info.absolute_paths:
@@ -2543,11 +2551,11 @@ class ControlNetDataset(BaseDataset):
subset.token_warmup_min,
subset.token_warmup_step,
resize_interpolation=subset.resize_interpolation,
subset.preference,
subset.preference_caption_prefix,
subset.preference_caption_suffix,
subset.non_preference_caption_prefix,
subset.non_preference_caption_suffix,
preference=subset.preference,
preference_caption_prefix=subset.preference_caption_prefix,
preference_caption_suffix=subset.preference_caption_suffix,
non_preference_caption_prefix=subset.non_preference_caption_prefix,
non_preference_caption_suffix=subset.non_preference_caption_suffix,
)
db_subsets.append(db_subset)