Merge branch 'sd3' into multi-gpu-caching

This commit is contained in:
Kohya S
2024-09-29 10:12:18 +09:00
13 changed files with 75 additions and 14 deletions

View File

@@ -661,6 +661,34 @@ class BaseDataset(torch.utils.data.Dataset):
self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
def adjust_min_max_bucket_reso_by_steps(
self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
) -> Tuple[int, int]:
# make min/max bucket reso to be multiple of bucket_reso_steps
if min_bucket_reso % bucket_reso_steps != 0:
adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps
logger.warning(
f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps"
f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}"
)
min_bucket_reso = adjusted_min_bucket_reso
if max_bucket_reso % bucket_reso_steps != 0:
adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps
logger.warning(
f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps"
f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}"
)
max_bucket_reso = adjusted_max_bucket_reso
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
return min_bucket_reso, max_bucket_reso
def set_seed(self, seed):
self.seed = seed
@@ -1712,12 +1740,9 @@ class DreamBoothDataset(BaseDataset):
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
@@ -2090,6 +2115,9 @@ class FineTuningDataset(BaseDataset):
self.enable_bucket = enable_bucket
if self.enable_bucket:
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
@@ -4154,8 +4182,20 @@ def add_dataset_arguments(
action="store_true",
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする",
)
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
parser.add_argument(
"--min_bucket_reso",
type=int,
default=256,
help="minimum resolution for buckets, must be divisible by bucket_reso_steps "
" / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--max_bucket_reso",
type=int,
default=1024,
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--bucket_reso_steps",
type=int,