fix flip_aug, alpha_mask, random_crop issue in caching strategy
@@ -993,9 +993,26 @@ class BaseDataset(torch.utils.data.Dataset):
         # sort by resolution
         image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
 
-        # split by resolution
-        batches = []
-        batch = []
+        # split by resolution and some conditions
+        class Condition:
+            def __init__(self, reso, flip_aug, alpha_mask, random_crop):
+                self.reso = reso
+                self.flip_aug = flip_aug
+                self.alpha_mask = alpha_mask
+                self.random_crop = random_crop
+
+            def __eq__(self, other):
+                return (
+                    self.reso == other.reso
+                    and self.flip_aug == other.flip_aug
+                    and self.alpha_mask == other.alpha_mask
+                    and self.random_crop == other.random_crop
+                )
+
+        batches: List[Tuple[Condition, List[ImageInfo]]] = []
+        batch: List[ImageInfo] = []
+        current_condition = None
+
         logger.info("checking cache validity...")
         for info in tqdm(image_infos):
             subset = self.image_to_subset[info.image_key]
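The hand-written __eq__ above is equivalent to what a frozen dataclass would generate; a minimal standalone sketch of that equivalence (not part of the commit):

from dataclasses import dataclass
from typing import Tuple

# standalone sketch: a frozen dataclass generates the same field-by-field
# __eq__ that Condition defines by hand in the hunk above
@dataclass(frozen=True)
class Condition:
    reso: Tuple[int, int]
    flip_aug: bool
    alpha_mask: bool
    random_crop: bool

assert Condition((512, 512), True, False, False) == Condition((512, 512), True, False, False)
assert Condition((512, 512), True, False, False) != Condition((512, 512), False, False, False)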
@@ -1016,20 +1033,23 @@ class BaseDataset(torch.utils.data.Dataset):
             if cache_available:  # do not add to batch
                 continue
 
-            # if last member of batch has different resolution, flush the batch
-            if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
-                batches.append(batch)
+            # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
+            condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+            if len(batch) > 0 and current_condition != condition:
+                batches.append((current_condition, batch))
                 batch = []
 
             batch.append(info)
+            current_condition = condition
 
             # if number of data in batch is enough, flush the batch
             if len(batch) >= caching_strategy.batch_size:
-                batches.append(batch)
+                batches.append((current_condition, batch))
                 batch = []
+                current_condition = None
 
         if len(batch) > 0:
-            batches.append(batch)
+            batches.append((current_condition, batch))
 
         # if cache to disk, don't cache latents in non-main process, set to info only
         if caching_strategy.cache_to_disk and not is_main_process:
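The flush logic above can be exercised in isolation. A standalone sketch with plain (condition, key) tuples standing in for ImageInfo and a literal batch_size in place of caching_strategy.batch_size (names are illustrative, not the repository's API):

def split_into_batches(items, batch_size):
    batches, batch, current_condition = [], [], None
    for condition, key in items:
        # flush when the full condition (reso + aug flags) changes,
        # not just when the resolution changes
        if len(batch) > 0 and current_condition != condition:
            batches.append((current_condition, batch))
            batch = []
        batch.append(key)
        current_condition = condition
        # flush full batches
        if len(batch) >= batch_size:
            batches.append((current_condition, batch))
            batch, current_condition = [], None
    # flush the remainder
    if len(batch) > 0:
        batches.append((current_condition, batch))
    return batches

# two 512x512 items that differ only in flip_aug land in separate batches:
items = [(("512x512", True), "a"), (("512x512", True), "b"), (("512x512", False), "c")]
print(split_into_batches(items, batch_size=4))
# [(('512x512', True), ['a', 'b']), (('512x512', False), ['c'])]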
@@ -1041,9 +1061,8 @@ class BaseDataset(torch.utils.data.Dataset):
 
         # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
         logger.info("caching latents...")
-        for batch in tqdm(batches, smoothing=1, total=len(batches)):
-            # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
-            caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+        for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
+            caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
 
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
         # multi-GPU is not supported here; use tools/cache_latents.py for that
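This last hunk is where the reported issue actually bit: the old loop read subset after the validity-check loop had finished, so subset.flip_aug / subset.alpha_mask / subset.random_crop always held the values of whichever image happened to be processed last, and those flags were applied when caching every batch. Carrying a Condition with each batch pins the right flags to the right images. A toy reproduction of the stale-variable pattern (hypothetical names):

# the loop variable `flags` leaks out of the first loop, so the second
# loop sees only its final value instead of each batch's own flags
batches = []
for flags, key in [(True, "a"), (False, "b")]:
    batches.append([key])
for batch in batches:
    print(batch, "cached with flip_aug =", flags)  # stale: always False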