add warning for bucket_reso_steps with SDXL

This commit is contained in:
Kohya S
2023-08-11 19:02:36 +09:00
parent bf31f18c46
commit 6889ee2b85
5 changed files with 23 additions and 0 deletions

View File

@@ -800,6 +800,12 @@ class BaseDataset(torch.utils.data.Dataset):
random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle()
def verify_bucket_reso_steps(self, min_steps: int):
assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, (
f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n"
+ f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります"
)
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
@@ -1831,6 +1837,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.set_caching_mode(caching_mode)
def verify_bucket_reso_steps(self, min_steps: int):
for dataset in self.datasets:
dataset.verify_bucket_reso_steps(min_steps)
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -2020,6 +2030,9 @@ class MinimalDataset(BaseDataset):
self.is_reg = False
self.image_dir = "dummy" # for metadata
def verify_bucket_reso_steps(self, min_steps: int):
pass
def is_latent_cacheable(self) -> bool:
return False