mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add warning for bucket_reso_steps with SDXL
This commit is contained in:
@@ -52,6 +52,10 @@ def main(args):
|
||||
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
||||
if args.bucket_reso_steps % 8 > 0:
|
||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||
if args.bucket_reso_steps % 32 > 0:
|
||||
print(
|
||||
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
|
||||
)
|
||||
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
|
||||
|
||||
@@ -800,6 +800,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
random.shuffle(self.buckets_indices)
|
||||
self.bucket_manager.shuffle()
|
||||
|
||||
def verify_bucket_reso_steps(self, min_steps: int):
|
||||
assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, (
|
||||
f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n"
|
||||
+ f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります"
|
||||
)
|
||||
|
||||
def is_latent_cacheable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
@@ -1831,6 +1837,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_caching_mode(caching_mode)
|
||||
|
||||
def verify_bucket_reso_steps(self, min_steps: int):
|
||||
for dataset in self.datasets:
|
||||
dataset.verify_bucket_reso_steps(min_steps)
|
||||
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||
|
||||
@@ -2020,6 +2030,9 @@ class MinimalDataset(BaseDataset):
|
||||
self.is_reg = False
|
||||
self.image_dir = "dummy" # for metadata
|
||||
|
||||
def verify_bucket_reso_steps(self, min_steps: int):
|
||||
pass
|
||||
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -98,6 +98,8 @@ def train(args):
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group, True)
|
||||
return
|
||||
|
||||
@@ -23,6 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
|
||||
@@ -19,6 +19,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
|
||||
Reference in New Issue
Block a user