mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
add minimal impl for masked loss
This commit is contained in:
@@ -41,12 +41,17 @@ from .train_util import (
|
||||
DatasetGroup,
|
||||
)
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def add_config_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
|
||||
parser.add_argument(
|
||||
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
|
||||
)
|
||||
|
||||
|
||||
# TODO: inherit Params class in Subset, Dataset
|
||||
@@ -248,9 +253,10 @@ class ConfigSanitizer:
|
||||
}
|
||||
|
||||
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
||||
assert (
|
||||
support_dreambooth or support_finetuning or support_controlnet
|
||||
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
|
||||
assert support_dreambooth or support_finetuning or support_controlnet, (
|
||||
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
|
||||
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
|
||||
)
|
||||
|
||||
self.db_subset_schema = self.__merge_dict(
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
@@ -362,7 +368,9 @@ class ConfigSanitizer:
|
||||
return self.argparse_config_validator(argparse_namespace)
|
||||
except MultipleInvalid:
|
||||
# XXX: this should be a bug
|
||||
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
|
||||
logger.error(
|
||||
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
|
||||
)
|
||||
raise
|
||||
|
||||
# NOTE: value would be overwritten by latter dict if there is already the same key
|
||||
@@ -547,11 +555,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
" ",
|
||||
)
|
||||
|
||||
logger.info(f'{info}')
|
||||
logger.info(f"{info}")
|
||||
|
||||
# make buckets first because it determines the length of dataset
|
||||
# and set the same seed for all datasets
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
for i, dataset in enumerate(datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
@@ -638,13 +646,17 @@ def load_user_config(file: str) -> dict:
|
||||
with open(file, "r") as f:
|
||||
config = json.load(f)
|
||||
except Exception:
|
||||
logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
||||
logger.error(
|
||||
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
||||
)
|
||||
raise
|
||||
elif file.name.lower().endswith(".toml"):
|
||||
try:
|
||||
config = toml.load(file)
|
||||
except Exception:
|
||||
logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
||||
logger.error(
|
||||
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
||||
)
|
||||
raise
|
||||
else:
|
||||
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
||||
@@ -671,13 +683,13 @@ if __name__ == "__main__":
|
||||
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
||||
|
||||
logger.info("[argparse_namespace]")
|
||||
logger.info(f'{vars(argparse_namespace)}')
|
||||
logger.info(f"{vars(argparse_namespace)}")
|
||||
|
||||
user_config = load_user_config(config_args.dataset_config)
|
||||
|
||||
logger.info("")
|
||||
logger.info("[user_config]")
|
||||
logger.info(f'{user_config}')
|
||||
logger.info(f"{user_config}")
|
||||
|
||||
sanitizer = ConfigSanitizer(
|
||||
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
||||
@@ -686,10 +698,10 @@ if __name__ == "__main__":
|
||||
|
||||
logger.info("")
|
||||
logger.info("[sanitized_user_config]")
|
||||
logger.info(f'{sanitized_user_config}')
|
||||
logger.info(f"{sanitized_user_config}")
|
||||
|
||||
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
||||
|
||||
logger.info("")
|
||||
logger.info("[blueprint]")
|
||||
logger.info(f'{blueprint}')
|
||||
logger.info(f"{blueprint}")
|
||||
|
||||
@@ -1810,6 +1810,9 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
assert (
|
||||
not subset.random_crop
|
||||
), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
|
||||
db_subset = DreamBoothSubset(
|
||||
subset.image_dir,
|
||||
False,
|
||||
|
||||
@@ -13,6 +13,7 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@@ -157,7 +158,7 @@ class NetworkTrainer:
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
||||
if use_user_config:
|
||||
logger.info(f"Loading dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
@@ -834,6 +835,16 @@ class NetworkTrainer:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
|
||||
if args.masked_loss:
|
||||
# mask image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
|
||||
# resize to the same size as the loss
|
||||
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
||||
mask_image = mask_image / 2 + 0.5
|
||||
loss = loss * mask_image
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
@@ -1050,6 +1061,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masked_loss",
|
||||
action="store_true",
|
||||
help="apply mask for caclulating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user