add minimal impl for masked loss

This commit is contained in:
Kohya S
2024-02-26 23:19:58 +09:00
parent 577e9913ca
commit f2c727fc8c
3 changed files with 45 additions and 14 deletions

View File

@@ -41,12 +41,17 @@ from .train_util import (
DatasetGroup,
)
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def add_config_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
parser.add_argument(
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
)
# TODO: inherit Params class in Subset, Dataset
@@ -248,9 +253,10 @@ class ConfigSanitizer:
}
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
assert (
support_dreambooth or support_finetuning or support_controlnet
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
assert support_dreambooth or support_finetuning or support_controlnet, (
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
+ " / DreamBooth モードか fine tuning モードか controlnet モードのども指定されていません。1つ以上指定してください。"
)
self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -362,7 +368,9 @@ class ConfigSanitizer:
return self.argparse_config_validator(argparse_namespace)
except MultipleInvalid:
# XXX: this should be a bug
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
logger.error(
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
)
raise
# NOTE: value would be overwritten by latter dict if there is already the same key
@@ -547,11 +555,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
" ",
)
logger.info(f'{info}')
logger.info(f"{info}")
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
dataset.make_buckets()
@@ -638,13 +646,17 @@ def load_user_config(file: str) -> dict:
with open(file, "r") as f:
config = json.load(f)
except Exception:
logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
logger.error(
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise
elif file.name.lower().endswith(".toml"):
try:
config = toml.load(file)
except Exception:
logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
logger.error(
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise
else:
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -671,13 +683,13 @@ if __name__ == "__main__":
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
logger.info("[argparse_namespace]")
logger.info(f'{vars(argparse_namespace)}')
logger.info(f"{vars(argparse_namespace)}")
user_config = load_user_config(config_args.dataset_config)
logger.info("")
logger.info("[user_config]")
logger.info(f'{user_config}')
logger.info(f"{user_config}")
sanitizer = ConfigSanitizer(
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
@@ -686,10 +698,10 @@ if __name__ == "__main__":
logger.info("")
logger.info("[sanitized_user_config]")
logger.info(f'{sanitized_user_config}')
logger.info(f"{sanitized_user_config}")
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
logger.info("")
logger.info("[blueprint]")
logger.info(f'{blueprint}')
logger.info(f"{blueprint}")