Support SD3.5M multi resolutional training

This commit is contained in:
Kohya S
2024-10-31 19:58:22 +09:00
parent 70a179e446
commit 1434d8506f
8 changed files with 215 additions and 10 deletions

View File

@@ -2510,6 +2510,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.verify_bucket_reso_steps(min_steps)
def get_resolutions(self) -> List[Tuple[int, int]]:
return [(dataset.width, dataset.height) for dataset in self.datasets]
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])