Fix training, validation split, revert to using upstream implemenation

This commit is contained in:
rockerBOO
2025-01-03 15:20:25 -05:00
parent 6604b36044
commit 0522070d19
5 changed files with 152 additions and 160 deletions

View File

@@ -482,7 +482,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset_klass = FineTuningDataset
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, is_training_dataset=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)
val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
@@ -500,16 +500,16 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset_klass = FineTuningDataset
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, is_training_dataset=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)
def print_info(_datasets):
def print_info(_datasets, dataset_type: str):
info = ""
for i, dataset in enumerate(_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[Dataset {i}]
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
@@ -527,7 +527,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
[Subset {j} of Dataset {i}]
[Subset {j} of {dataset_type} {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
@@ -544,8 +544,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""), " ")
if is_dreambooth:
@@ -561,67 +561,22 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
logger.info(info)
print_info(datasets)
print_info(datasets, "Dataset")
if len(val_datasets) > 0:
logger.info("Validation dataset")
print_info(val_datasets)
if len(val_datasets) > 0:
info = ""
for i, dataset in enumerate(val_datasets):
info += dedent(
f"""\
[Validation Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)
if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
else:
info += "\n"
for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Validation Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
"""
),
" ",
)
logger.info(info)
print_info(val_datasets, "Validation Dataset")
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
logger.info(f"[Prepare dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
for i, dataset in enumerate(val_datasets):
logger.info(f"[Validation Dataset {i}]")
logger.info(f"[Prepare validation dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)