commit 78cfb01922
parent b558a5b73d
Author: gesen2egee
Date:   2024-03-10 18:55:48 +08:00

2 changed files with 230 additions and 89 deletions


@@ -41,12 +41,17 @@ from .train_util import (
     DatasetGroup,
 )
 from .utils import setup_logging

 setup_logging()
 import logging

 logger = logging.getLogger(__name__)


 def add_config_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
+    parser.add_argument(
+        "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
+    )


 # TODO: inherit Params class in Subset, Dataset
@@ -60,6 +65,8 @@ class BaseSubsetParams:
     caption_separator: str = (",",)
     keep_tokens: int = 0
     keep_tokens_separator: str = (None,)
+    secondary_separator: Optional[str] = None
+    enable_wildcard: bool = False
     color_aug: bool = False
     flip_aug: bool = False
     face_crop_aug_range: Optional[Tuple[float, float]] = None
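The two new subset options above (together with the matching ConfigSanitizer schema entries in the next hunk) can be set per subset in a dataset_config file. A minimal sketch, assuming the datasets/subsets layout those config files use; the values and the image_dir path are hypothetical:

    user_config = {
        "datasets": [
            {
                "subsets": [
                    {
                        "image_dir": "train/images",   # hypothetical path
                        "keep_tokens_separator": "|||",
                        "secondary_separator": ";;;",  # new option
                        "enable_wildcard": True,       # new option
                    }
                ]
            }
        ]
    }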
@@ -181,6 +188,8 @@ class ConfigSanitizer:
         "shuffle_caption": bool,
         "keep_tokens": int,
         "keep_tokens_separator": str,
+        "secondary_separator": str,
+        "enable_wildcard": bool,
         "token_warmup_min": int,
         "token_warmup_step": Any(float, int),
         "caption_prefix": str,
@@ -247,9 +256,10 @@ class ConfigSanitizer:
     }

     def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
-        assert (
-            support_dreambooth or support_finetuning or support_controlnet
-        ), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
+        assert support_dreambooth or support_finetuning or support_controlnet, (
+            "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
+            + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
+        )

         self.db_subset_schema = self.__merge_dict(
             self.SUBSET_ASCENDABLE_SCHEMA,
@@ -361,7 +371,9 @@ class ConfigSanitizer:
             return self.argparse_config_validator(argparse_namespace)
         except MultipleInvalid:
             # XXX: this should be a bug
-            logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
+            logger.error(
+                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
+            )
             raise

     # NOTE: value would be overwritten by latter dict if there is already the same key
@@ -447,7 +459,6 @@ class BlueprintGenerator:
         return default_value


 def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
     datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
@@ -467,7 +478,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         datasets.append(dataset)

     val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
     for dataset_blueprint in dataset_group_blueprint.datasets:
         if dataset_blueprint.params.validation_split <= 0.0:
             continue
@@ -485,75 +496,174 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
         val_datasets.append(dataset)

-    def print_info(_datasets):
-        info = ""
-        for i, dataset in enumerate(_datasets):
-            is_dreambooth = isinstance(dataset, DreamBoothDataset)
-            is_controlnet = isinstance(dataset, ControlNetDataset)
-            info += dedent(f"""\
-                [Dataset {i}]
-                  batch_size: {dataset.batch_size}
-                  resolution: {(dataset.width, dataset.height)}
-                  enable_bucket: {dataset.enable_bucket}
-                """)
-
-            if dataset.enable_bucket:
-                info += indent(dedent(f"""\
-                    min_bucket_reso: {dataset.min_bucket_reso}
-                    max_bucket_reso: {dataset.max_bucket_reso}
-                    bucket_reso_steps: {dataset.bucket_reso_steps}
-                    bucket_no_upscale: {dataset.bucket_no_upscale}
-                    \n"""), "  ")
-            else:
-                info += "\n"
-
-            for j, subset in enumerate(dataset.subsets):
-                info += indent(dedent(f"""\
-                    [Subset {j} of Dataset {i}]
-                      image_dir: "{subset.image_dir}"
-                      image_count: {subset.img_count}
-                      num_repeats: {subset.num_repeats}
-                      shuffle_caption: {subset.shuffle_caption}
-                      keep_tokens: {subset.keep_tokens}
-                      caption_dropout_rate: {subset.caption_dropout_rate}
-                      caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
-                      caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
-                      caption_prefix: {subset.caption_prefix}
-                      caption_suffix: {subset.caption_suffix}
-                      color_aug: {subset.color_aug}
-                      flip_aug: {subset.flip_aug}
-                      face_crop_aug_range: {subset.face_crop_aug_range}
-                      random_crop: {subset.random_crop}
-                      token_warmup_min: {subset.token_warmup_min},
-                      token_warmup_step: {subset.token_warmup_step},
-                    """), "  ")
-
-                if is_dreambooth:
-                    info += indent(dedent(f"""\
-                        is_reg: {subset.is_reg}
-                        class_tokens: {subset.class_tokens}
-                        caption_extension: {subset.caption_extension}
-                        \n"""), "    ")
-                elif not is_controlnet:
-                    info += indent(dedent(f"""\
-                        metadata_file: {subset.metadata_file}
-                        \n"""), "    ")
-
-        print(info)
-
-    print_info(datasets)
+    # print info
+    info = ""
+    for i, dataset in enumerate(datasets):
+        is_dreambooth = isinstance(dataset, DreamBoothDataset)
+        is_controlnet = isinstance(dataset, ControlNetDataset)
+        info += dedent(
+            f"""\
+            [Dataset {i}]
+              batch_size: {dataset.batch_size}
+              resolution: {(dataset.width, dataset.height)}
+              enable_bucket: {dataset.enable_bucket}
+              network_multiplier: {dataset.network_multiplier}
+            """
+        )
+
+        if dataset.enable_bucket:
+            info += indent(
+                dedent(
+                    f"""\
+                    min_bucket_reso: {dataset.min_bucket_reso}
+                    max_bucket_reso: {dataset.max_bucket_reso}
+                    bucket_reso_steps: {dataset.bucket_reso_steps}
+                    bucket_no_upscale: {dataset.bucket_no_upscale}
+                    \n"""
+                ),
+                "  ",
+            )
+        else:
+            info += "\n"
+
+        for j, subset in enumerate(dataset.subsets):
+            info += indent(
+                dedent(
+                    f"""\
+                    [Subset {j} of Dataset {i}]
+                      image_dir: "{subset.image_dir}"
+                      image_count: {subset.img_count}
+                      num_repeats: {subset.num_repeats}
+                      shuffle_caption: {subset.shuffle_caption}
+                      keep_tokens: {subset.keep_tokens}
+                      keep_tokens_separator: {subset.keep_tokens_separator}
+                      caption_dropout_rate: {subset.caption_dropout_rate}
+                      caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
+                      caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
+                      caption_prefix: {subset.caption_prefix}
+                      caption_suffix: {subset.caption_suffix}
+                      color_aug: {subset.color_aug}
+                      flip_aug: {subset.flip_aug}
+                      face_crop_aug_range: {subset.face_crop_aug_range}
+                      random_crop: {subset.random_crop}
+                      token_warmup_min: {subset.token_warmup_min},
+                      token_warmup_step: {subset.token_warmup_step},
+                    """
+                ),
+                "  ",
+            )
+
+            if is_dreambooth:
+                info += indent(
+                    dedent(
+                        f"""\
+                        is_reg: {subset.is_reg}
+                        class_tokens: {subset.class_tokens}
+                        caption_extension: {subset.caption_extension}
+                        \n"""
+                    ),
+                    "    ",
+                )
+            elif not is_controlnet:
+                info += indent(
+                    dedent(
+                        f"""\
+                        metadata_file: {subset.metadata_file}
+                        \n"""
+                    ),
+                    "    ",
+                )
+
+    logger.info(f'{info}')
+
+    # print validation info
+    info = ""
+    for i, dataset in enumerate(val_datasets):
+        is_dreambooth = isinstance(dataset, DreamBoothDataset)
+        is_controlnet = isinstance(dataset, ControlNetDataset)
+        info += dedent(
+            f"""\
+            [Validation Dataset {i}]
+              batch_size: {dataset.batch_size}
+              resolution: {(dataset.width, dataset.height)}
+              enable_bucket: {dataset.enable_bucket}
+              network_multiplier: {dataset.network_multiplier}
+            """
+        )
+
+        if dataset.enable_bucket:
+            info += indent(
+                dedent(
+                    f"""\
+                    min_bucket_reso: {dataset.min_bucket_reso}
+                    max_bucket_reso: {dataset.max_bucket_reso}
+                    bucket_reso_steps: {dataset.bucket_reso_steps}
+                    bucket_no_upscale: {dataset.bucket_no_upscale}
+                    \n"""
+                ),
+                "  ",
+            )
+        else:
+            info += "\n"
+
+        for j, subset in enumerate(dataset.subsets):
+            info += indent(
+                dedent(
+                    f"""\
+                    [Subset {j} of Dataset {i}]
+                      image_dir: "{subset.image_dir}"
+                      image_count: {subset.img_count}
+                      num_repeats: {subset.num_repeats}
+                      shuffle_caption: {subset.shuffle_caption}
+                      keep_tokens: {subset.keep_tokens}
+                      keep_tokens_separator: {subset.keep_tokens_separator}
+                      caption_dropout_rate: {subset.caption_dropout_rate}
+                      caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
+                      caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
+                      caption_prefix: {subset.caption_prefix}
+                      caption_suffix: {subset.caption_suffix}
+                      color_aug: {subset.color_aug}
+                      flip_aug: {subset.flip_aug}
+                      face_crop_aug_range: {subset.face_crop_aug_range}
+                      random_crop: {subset.random_crop}
+                      token_warmup_min: {subset.token_warmup_min},
+                      token_warmup_step: {subset.token_warmup_step},
+                    """
+                ),
+                "  ",
+            )
+
+            if is_dreambooth:
+                info += indent(
+                    dedent(
+                        f"""\
+                        is_reg: {subset.is_reg}
+                        class_tokens: {subset.class_tokens}
+                        caption_extension: {subset.caption_extension}
+                        \n"""
+                    ),
+                    "    ",
+                )
+            elif not is_controlnet:
+                info += indent(
+                    dedent(
+                        f"""\
+                        metadata_file: {subset.metadata_file}
+                        \n"""
+                    ),
+                    "    ",
+                )
+
+    logger.info(f'{info}')
-    if len(val_datasets) > 0:
-        print("Validation dataset")
-        print_info(val_datasets)

     # make buckets first because it determines the length of dataset
     # and set the same seed for all datasets
     seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no

     for i, dataset in enumerate(datasets):
-        print(f"[Dataset {i}]")
+        logger.info(f"[Dataset {i}]")
         dataset.make_buckets()
         dataset.set_seed(seed)

     for i, dataset in enumerate(val_datasets):
         print(f"[Validation Dataset {i}]")
         dataset.make_buckets()
@@ -562,8 +672,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
     return (
         DatasetGroup(datasets),
         DatasetGroup(val_datasets) if val_datasets else None
     )


 def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
     def extract_dreambooth_params(name: str) -> Tuple[int, str]:
         tokens = name.split("_")
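generate_dataset_group_by_blueprint returns a pair here, with None in place of the validation group when no dataset reserves a split. A caller-side sketch of unpacking it; the caller names and helper are assumed, not taken from this diff:

    train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    if val_dataset_group is not None:
        # only build a validation dataloader when some dataset reserved a split
        val_dataloader = make_val_dataloader(val_dataset_group)  # hypothetical helper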
@@ -642,13 +752,17 @@ def load_user_config(file: str) -> dict:
             with open(file, "r") as f:
                 config = json.load(f)
         except Exception:
-            logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+            logger.error(
+                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+            )
             raise
     elif file.name.lower().endswith(".toml"):
         try:
             config = toml.load(file)
         except Exception:
-            logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
+            logger.error(
+                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
+            )
             raise
     else:
         raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -675,13 +789,13 @@ if __name__ == "__main__":
     train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)

     logger.info("[argparse_namespace]")
-    logger.info(f'{vars(argparse_namespace)}')
+    logger.info(f"{vars(argparse_namespace)}")

     user_config = load_user_config(config_args.dataset_config)

     logger.info("")
     logger.info("[user_config]")
-    logger.info(f'{user_config}')
+    logger.info(f"{user_config}")

     sanitizer = ConfigSanitizer(
         config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
@@ -690,10 +804,10 @@ if __name__ == "__main__":
     logger.info("")
     logger.info("[sanitized_user_config]")
-    logger.info(f'{sanitized_user_config}')
+    logger.info(f"{sanitized_user_config}")

     blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

     logger.info("")
     logger.info("[blueprint]")
-    logger.info(f'{blueprint}')
+    logger.info(f"{blueprint}")


@@ -44,6 +44,7 @@ from library.utils import setup_logging, add_logging_arguments
 setup_logging()
 import logging
+import itertools

 logger = logging.getLogger(__name__)
@@ -438,6 +439,7 @@ class NetworkTrainer:
             num_workers=n_workers,
             persistent_workers=args.persistent_data_loader_workers,
         )
+        cyclic_val_dataloader = itertools.cycle(val_dataloader)

         # 学習ステップ数を計算する
         if args.max_train_epochs is not None:
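itertools.cycle is what lets the training loop below call next(cyclic_val_dataloader) at any step without hitting StopIteration. A small standalone sketch of that behaviour:

    import itertools

    loader = ["batch0", "batch1", "batch2"]    # stand-in for val_dataloader
    cyclic = itertools.cycle(loader)
    print([next(cyclic) for _ in range(5)])    # ['batch0', 'batch1', 'batch2', 'batch0', 'batch1']

Note that cycle replays the items it has already seen in the same order, so validation batches repeat without reshuffling between passes.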
@@ -979,23 +981,24 @@ class NetworkTrainer:
                 if args.logging_dir is not None:
                     logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
                     accelerator.log(logs, step=global_step)

-                if global_step % 25 == 0:
-                    if len(val_dataloader) > 0:
-                        print("Validating バリデーション処理...")
-
-                        with torch.no_grad():
-                            val_dataloader_iter = iter(val_dataloader)
-                            batch = next(val_dataloader_iter)
-                            is_train = False
-                            loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
-
-                        current_loss = loss.detach().item()
-                        val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
-
-                        if args.logging_dir is not None:
-                            avr_loss: float = val_loss_recorder.moving_average
-                            logs = {"loss/validation_current": current_loss}
-                            accelerator.log(logs, step=global_step)
+                if args.validation_every_n_step is not None:
+                    if global_step % args.validation_every_n_step == 0:
+                        if len(val_dataloader) > 0:
+                            print("Validating バリデーション処理...")
+                            total_loss = 0.0
+                            with torch.no_grad():
+                                for val_step in range(min(len(val_dataloader), args.validation_batches)):
+                                    is_train = False
+                                    batch = next(cyclic_val_dataloader)
+                                    loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
+                                    total_loss += loss.detach().item()
+                            current_loss = total_loss / args.validation_batches
+                            val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
+                            if args.logging_dir is not None:
+                                avr_loss: float = val_loss_recorder.moving_average
+                                logs = {"loss/avr_val_loss": avr_loss}
+                                accelerator.log(logs, step=global_step)

                 if global_step >= args.max_train_steps:
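The validation branch above averages the loss over up to validation_batches batches pulled from the cyclic dataloader. A self-contained sketch of that pattern, with compute_loss as a hypothetical stand-in for process_val_batch:

    import itertools
    import torch

    def compute_loss(batch):                       # hypothetical stand-in for process_val_batch
        return torch.tensor(0.0)

    validation_batches = 4                         # i.e. args.validation_batches
    val_dataloader = [object(), object()]          # stand-in for the real DataLoader
    cyclic_val_dataloader = itertools.cycle(val_dataloader)

    num_batches = min(len(val_dataloader), validation_batches)
    total_loss = 0.0
    with torch.no_grad():
        for _ in range(num_batches):
            batch = next(cyclic_val_dataloader)
            total_loss += compute_loss(batch).item()
    avg_val_loss = total_loss / num_batches        # averaged over the batches actually run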
@@ -1005,12 +1008,24 @@ class NetworkTrainer:
                 logs = {"loss/epoch_average": loss_recorder.moving_average}
                 accelerator.log(logs, step=epoch + 1)

-            if len(val_dataloader) > 0:
-                if args.logging_dir is not None:
-                    avr_loss: float = val_loss_recorder.moving_average
-                    logs = {"loss/validation_epoch_average": avr_loss}
-                    accelerator.log(logs, step=epoch + 1)
+            if args.validation_every_n_step is None:
+                if len(val_dataloader) > 0:
+                    print("Validating バリデーション処理...")
+                    total_loss = 0.0
+                    with torch.no_grad():
+                        for val_step in range(min(len(val_dataloader), args.validation_batches)):
+                            is_train = False
+                            batch = next(cyclic_val_dataloader)
+                            loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
+                            total_loss += loss.detach().item()
+                    current_loss = total_loss / args.validation_batches
+                    val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
+                    if args.logging_dir is not None:
+                        avr_loss: float = val_loss_recorder.moving_average
+                        logs = {"loss/val_epoch_average": avr_loss}
+                        accelerator.log(logs, step=epoch + 1)

             accelerator.wait_for_everyone()

             # 指定エポックごとにモデルを保存
@@ -1162,6 +1177,18 @@ def setup_parser() -> argparse.ArgumentParser:
         default=0.0,
         help="Split for validation images out of the training dataset"
     )
+    parser.add_argument(
+        "--validation_every_n_step",
+        type=int,
+        default=None,
+        help="Run validation every N training steps; if not set, validation runs once per epoch"
+    )
+    parser.add_argument(
+        "--validation_batches",
+        type=int,
+        default=1,
+        help="Number of validation batches to average the validation loss over (default: 1)"
+    )
     return parser
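Together the two new flags choose between step-based and epoch-based validation. A compact sketch of the schedule they imply, assuming the semantics of the branches added earlier in this file:

    def should_run_validation(global_step: int, end_of_epoch: bool, validation_every_n_step) -> bool:
        if validation_every_n_step is not None:
            return global_step % validation_every_n_step == 0   # step-based validation
        return end_of_epoch                                      # otherwise validate once per epoch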