add minimal impl for masked loss

This commit is contained in:
Kohya S
2024-02-26 23:19:58 +09:00
parent 577e9913ca
commit f2c727fc8c
3 changed files with 45 additions and 14 deletions

View File

@@ -41,12 +41,17 @@ from .train_util import (
DatasetGroup, DatasetGroup,
) )
from .utils import setup_logging from .utils import setup_logging
setup_logging() setup_logging()
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def add_config_arguments(parser: argparse.ArgumentParser): def add_config_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") parser.add_argument(
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
)
# TODO: inherit Params class in Subset, Dataset # TODO: inherit Params class in Subset, Dataset
@@ -248,9 +253,10 @@ class ConfigSanitizer:
} }
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
assert ( assert support_dreambooth or support_finetuning or support_controlnet, (
support_dreambooth or support_finetuning or support_controlnet "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" + " / DreamBooth モードか fine tuning モードか controlnet モードのども指定されていません。1つ以上指定してください。"
)
self.db_subset_schema = self.__merge_dict( self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA,
@@ -362,7 +368,9 @@ class ConfigSanitizer:
return self.argparse_config_validator(argparse_namespace) return self.argparse_config_validator(argparse_namespace)
except MultipleInvalid: except MultipleInvalid:
# XXX: this should be a bug # XXX: this should be a bug
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") logger.error(
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
)
raise raise
# NOTE: value would be overwritten by latter dict if there is already the same key # NOTE: value would be overwritten by latter dict if there is already the same key
@@ -547,11 +555,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
" ", " ",
) )
logger.info(f'{info}') logger.info(f"{info}")
# make buckets first because it determines the length of dataset # make buckets first because it determines the length of dataset
# and set the same seed for all datasets # and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets): for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]") logger.info(f"[Dataset {i}]")
dataset.make_buckets() dataset.make_buckets()
@@ -638,13 +646,17 @@ def load_user_config(file: str) -> dict:
with open(file, "r") as f: with open(file, "r") as f:
config = json.load(f) config = json.load(f)
except Exception: except Exception:
logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") logger.error(
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise raise
elif file.name.lower().endswith(".toml"): elif file.name.lower().endswith(".toml"):
try: try:
config = toml.load(file) config = toml.load(file)
except Exception: except Exception:
logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") logger.error(
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise raise
else: else:
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -671,13 +683,13 @@ if __name__ == "__main__":
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
logger.info("[argparse_namespace]") logger.info("[argparse_namespace]")
logger.info(f'{vars(argparse_namespace)}') logger.info(f"{vars(argparse_namespace)}")
user_config = load_user_config(config_args.dataset_config) user_config = load_user_config(config_args.dataset_config)
logger.info("") logger.info("")
logger.info("[user_config]") logger.info("[user_config]")
logger.info(f'{user_config}') logger.info(f"{user_config}")
sanitizer = ConfigSanitizer( sanitizer = ConfigSanitizer(
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
@@ -686,10 +698,10 @@ if __name__ == "__main__":
logger.info("") logger.info("")
logger.info("[sanitized_user_config]") logger.info("[sanitized_user_config]")
logger.info(f'{sanitized_user_config}') logger.info(f"{sanitized_user_config}")
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
logger.info("") logger.info("")
logger.info("[blueprint]") logger.info("[blueprint]")
logger.info(f'{blueprint}') logger.info(f"{blueprint}")

View File

@@ -1810,6 +1810,9 @@ class ControlNetDataset(BaseDataset):
db_subsets = [] db_subsets = []
for subset in subsets: for subset in subsets:
assert (
not subset.random_crop
), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
db_subset = DreamBoothSubset( db_subset = DreamBoothSubset(
subset.image_dir, subset.image_dir,
False, False,

View File

@@ -13,6 +13,7 @@ from tqdm import tqdm
import torch import torch
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
init_ipex() init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@@ -157,7 +158,7 @@ class NetworkTrainer:
# データセットを準備する # データセットを準備する
if args.dataset_class is None: if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config: if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}") logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config) user_config = config_util.load_user_config(args.dataset_config)
@@ -834,6 +835,16 @@ class NetworkTrainer:
target = noise target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
if args.masked_loss:
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel
# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -1050,6 +1061,11 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true", action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
) )
parser.add_argument(
"--masked_loss",
action="store_true",
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
)
return parser return parser