gesen2egee
2024-03-10 18:55:48 +08:00
parent b558a5b73d
commit 78cfb01922
2 changed files with 230 additions and 89 deletions

View File

@@ -41,12 +41,17 @@ from .train_util import (
DatasetGroup,
)
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def add_config_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
parser.add_argument(
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
)
# TODO: inherit Params class in Subset, Dataset
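For orientation, a minimal sketch (not part of this commit) of how a training script consumes add_config_arguments; the import path is an assumption based on the library-package imports above:

import argparse
from library.config_util import add_config_arguments  # assumed module path

parser = argparse.ArgumentParser()
add_config_arguments(parser)                 # registers the --dataset_config Path option
args = parser.parse_args(["--dataset_config", "dataset_config.toml"])
print(args.dataset_config)                   # dataset_config.toml (a pathlib.Path)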
@@ -60,6 +65,8 @@ class BaseSubsetParams:
caption_separator: str = (",",)
keep_tokens: int = 0
keep_tokens_separator: str = (None,)
secondary_separator: Optional[str] = None
enable_wildcard: bool = False
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
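The two new fields slot into the subset dataclass with safe defaults. A tiny illustration (my sketch, run where BaseSubsetParams is importable); the semantics, a secondary tag separator and wildcard-enabled captions, are paraphrased from the option names:

from dataclasses import asdict

params = BaseSubsetParams(secondary_separator=";;", enable_wildcard=True)
print(asdict(params)["secondary_separator"], asdict(params)["enable_wildcard"])  # ;; True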
@@ -181,6 +188,8 @@ class ConfigSanitizer:
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"secondary_separator": str,
"enable_wildcard": bool,
"token_warmup_min": int,
"token_warmup_step": Any(float, int),
"caption_prefix": str,
@@ -247,9 +256,10 @@ class ConfigSanitizer:
}
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
assert (
support_dreambooth or support_finetuning or support_controlnet
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
assert support_dreambooth or support_finetuning or support_controlnet, (
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
+ " / DreamBooth モードか fine tuning モードか controlnet モードのども指定されていません。1つ以上指定してください。"
)
self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -361,7 +371,9 @@ class ConfigSanitizer:
return self.argparse_config_validator(argparse_namespace)
except MultipleInvalid:
# XXX: this should be a bug
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
logger.error(
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
)
raise
# NOTE: value would be overwritten by latter dict if there is already the same key
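The NOTE above relies on ordinary right-biased dict merging; a one-screen sketch of that behavior with illustrative names (the project's own helper is not reproduced here):

def merge_dicts(*dicts):
    merged = {}
    for d in dicts:
        merged.update(d)  # keys already present are overwritten by the latter dict
    return merged

print(merge_dicts({"keep_tokens": 0, "flip_aug": False}, {"keep_tokens": 2}))
# {'keep_tokens': 2, 'flip_aug': False}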
@@ -447,7 +459,6 @@ class BlueprintGenerator:
return default_value
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
@@ -485,35 +496,47 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)
def print_info(_datasets):
# print info
info = ""
for i, dataset in enumerate(_datasets):
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
info += dedent(
f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")
network_multiplier: {dataset.network_multiplier}
"""
)
if dataset.enable_bucket:
info += indent(dedent(f"""\
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), " ")
\n"""
),
" ",
)
else:
info += "\n"
for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
info += indent(
dedent(
f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
@@ -525,32 +548,119 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), " ")
"""
),
" ",
)
if is_dreambooth:
info += indent(dedent(f"""\
info += indent(
dedent(
f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
\n"""
),
" ",
)
elif not is_controlnet:
info += indent(dedent(f"""\
info += indent(
dedent(
f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")
\n"""
),
" ",
)
print(info)
logger.info(f'{info}')
print_info(datasets)
# print validation info
info = ""
for i, dataset in enumerate(val_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(
f"""\
[Validation Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)
if len(val_datasets) > 0:
print("Validation dataset")
print_info(val_datasets)
if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
else:
info += "\n"
for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""
),
" ",
)
if is_dreambooth:
info += indent(
dedent(
f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""
),
" ",
)
elif not is_controlnet:
info += indent(
dedent(
f"""\
metadata_file: {subset.metadata_file}
\n"""
),
" ",
)
logger.info(f'{info}')
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
print(f"[Dataset {i}]")
logger.info(f"[Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
@@ -642,13 +752,17 @@ def load_user_config(file: str) -> dict:
with open(file, "r") as f:
config = json.load(f)
except Exception:
logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
logger.error(
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise
elif file.name.lower().endswith(".toml"):
try:
config = toml.load(file)
except Exception:
logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
logger.error(
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise
else:
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -675,13 +789,13 @@ if __name__ == "__main__":
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
logger.info("[argparse_namespace]")
logger.info(f'{vars(argparse_namespace)}')
logger.info(f"{vars(argparse_namespace)}")
user_config = load_user_config(config_args.dataset_config)
logger.info("")
logger.info("[user_config]")
logger.info(f'{user_config}')
logger.info(f"{user_config}")
sanitizer = ConfigSanitizer(
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
@@ -690,10 +804,10 @@ if __name__ == "__main__":
logger.info("")
logger.info("[sanitized_user_config]")
logger.info(f'{sanitized_user_config}')
logger.info(f"{sanitized_user_config}")
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
logger.info("")
logger.info("[blueprint]")
logger.info(f'{blueprint}')
logger.info(f"{blueprint}")

View File

@@ -44,6 +44,7 @@ from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
import itertools
logger = logging.getLogger(__name__)
@@ -438,6 +439,7 @@ class NetworkTrainer:
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
cyclic_val_dataloader = itertools.cycle(val_dataloader)
# 学習ステップ数を計算する (compute the number of training steps)
if args.max_train_epochs is not None:
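itertools.cycle above gives the trainer an endless stream of validation batches; a toy, self-contained sketch of the pattern (plain integers instead of the real dataset):

import itertools
from torch.utils.data import DataLoader

val_loader = DataLoader(list(range(5)), batch_size=2)  # yields 3 batches per pass
cyclic = itertools.cycle(val_loader)

for _ in range(4):              # ask for more batches than one pass provides
    print(next(cyclic))         # wraps around instead of raising StopIteration

One caveat of cycle is that it replays the elements saved from its first pass, so any shuffling done by the loader is not refreshed on later laps.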
@@ -980,22 +982,23 @@ class NetworkTrainer:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs, step=global_step)
if global_step % 25 == 0:
if args.validation_every_n_step is not None:
if global_step % (args.validation_every_n_step) == 0:
if len(val_dataloader) > 0:
print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
val_dataloader_iter = iter(val_dataloader)
batch = next(val_dataloader_iter)
for val_step in range(min(len(val_dataloader), args.validation_batches)):
is_train = False
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
current_loss = loss.detach().item()
total_loss += loss.detach().item()
current_loss = total_loss / args.validation_batches
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_current": current_loss}
logs = {"loss/avr_val_loss": avr_loss}
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
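Stripped of trainer plumbing, the step-based validation added above amounts to the following (a sketch under the assumption that the step function returns a scalar loss tensor; every name except the args flags is mine):

import torch

def run_validation(val_step_fn, cyclic_loader, num_batches):
    """Average a loss over a fixed number of validation batches."""
    total = 0.0
    with torch.no_grad():
        for _ in range(num_batches):
            batch = next(cyclic_loader)      # the cycled loader never runs dry
            total += val_step_fn(batch).item()
    return total / num_batches

# invoked roughly as:
#   if args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0:
#       val_loss = run_validation(step_fn, cyclic_val_dataloader, args.validation_batches)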
@@ -1005,10 +1008,22 @@ class NetworkTrainer:
logs = {"loss/epoch_average": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
if args.validation_every_n_step is None:
if len(val_dataloader) > 0:
print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
for val_step in range(min(len(val_dataloader), args.validation_batches)):
is_train = False
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
total_loss += loss.detach().item()
current_loss = total_loss / args.validation_batches
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_epoch_average": avr_loss}
logs = {"loss/val_epoch_average": avr_loss}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
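val_loss_recorder mirrors the training-side recorder: it accumulates losses and exposes a moving average for logging. A minimal stand-in (my own class, not the project's recorder) showing the idea:

from collections import deque

class SimpleLossRecorder:
    """Keeps the last N losses and exposes their mean, mimicking a moving average."""
    def __init__(self, window: int = 100):
        self.losses = deque(maxlen=window)

    def add(self, loss: float) -> None:
        self.losses.append(loss)

    @property
    def moving_average(self) -> float:
        return sum(self.losses) / len(self.losses)

rec = SimpleLossRecorder(window=3)
for l in (0.9, 0.7, 0.5, 0.3):
    rec.add(l)
print(rec.moving_average)  # mean of the last 3 values: 0.5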
@@ -1162,6 +1177,18 @@ def setup_parser() -> argparse.ArgumentParser:
default=0.0,
help="Split for validation images out of the training dataset"
)
parser.add_argument(
"--validation_every_n_step",
type=int,
default=None,
help="Number of steps for counting validation loss. By default, validation per epoch is performed"
)
parser.add_argument(
"--validation_batches",
type=int,
default=1,
help="Number of val steps for counting validation loss. By default, validation one batch is performed"
)
return parser
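Finally, a standalone sketch of how the two new flags parse; it re-declares just these arguments with the same defaults rather than invoking the full setup_parser, whose remaining options lie outside this diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--validation_every_n_step", type=int, default=None)
parser.add_argument("--validation_batches", type=int, default=1)

args = parser.parse_args(["--validation_every_n_step", "200", "--validation_batches", "4"])
print(args.validation_every_n_step, args.validation_batches)  # 200 4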