mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
add minimal impl for masked loss
This commit is contained in:
@@ -41,12 +41,17 @@ from .train_util import (
|
|||||||
DatasetGroup,
|
DatasetGroup,
|
||||||
)
|
)
|
||||||
from .utils import setup_logging
|
from .utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def add_config_arguments(parser: argparse.ArgumentParser):
|
def add_config_arguments(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
|
parser.add_argument(
|
||||||
|
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: inherit Params class in Subset, Dataset
|
# TODO: inherit Params class in Subset, Dataset
|
||||||
@@ -248,9 +253,10 @@ class ConfigSanitizer:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
||||||
assert (
|
assert support_dreambooth or support_finetuning or support_controlnet, (
|
||||||
support_dreambooth or support_finetuning or support_controlnet
|
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
|
||||||
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
|
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
|
||||||
|
)
|
||||||
|
|
||||||
self.db_subset_schema = self.__merge_dict(
|
self.db_subset_schema = self.__merge_dict(
|
||||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||||
@@ -362,7 +368,9 @@ class ConfigSanitizer:
|
|||||||
return self.argparse_config_validator(argparse_namespace)
|
return self.argparse_config_validator(argparse_namespace)
|
||||||
except MultipleInvalid:
|
except MultipleInvalid:
|
||||||
# XXX: this should be a bug
|
# XXX: this should be a bug
|
||||||
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
|
logger.error(
|
||||||
|
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# NOTE: value would be overwritten by latter dict if there is already the same key
|
# NOTE: value would be overwritten by latter dict if there is already the same key
|
||||||
@@ -547,11 +555,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
|||||||
" ",
|
" ",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'{info}')
|
logger.info(f"{info}")
|
||||||
|
|
||||||
# make buckets first because it determines the length of dataset
|
# make buckets first because it determines the length of dataset
|
||||||
# and set the same seed for all datasets
|
# and set the same seed for all datasets
|
||||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||||
for i, dataset in enumerate(datasets):
|
for i, dataset in enumerate(datasets):
|
||||||
logger.info(f"[Dataset {i}]")
|
logger.info(f"[Dataset {i}]")
|
||||||
dataset.make_buckets()
|
dataset.make_buckets()
|
||||||
@@ -638,13 +646,17 @@ def load_user_config(file: str) -> dict:
|
|||||||
with open(file, "r") as f:
|
with open(file, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
logger.error(
|
||||||
|
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
elif file.name.lower().endswith(".toml"):
|
elif file.name.lower().endswith(".toml"):
|
||||||
try:
|
try:
|
||||||
config = toml.load(file)
|
config = toml.load(file)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
logger.error(
|
||||||
|
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
||||||
@@ -671,13 +683,13 @@ if __name__ == "__main__":
|
|||||||
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
||||||
|
|
||||||
logger.info("[argparse_namespace]")
|
logger.info("[argparse_namespace]")
|
||||||
logger.info(f'{vars(argparse_namespace)}')
|
logger.info(f"{vars(argparse_namespace)}")
|
||||||
|
|
||||||
user_config = load_user_config(config_args.dataset_config)
|
user_config = load_user_config(config_args.dataset_config)
|
||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("[user_config]")
|
logger.info("[user_config]")
|
||||||
logger.info(f'{user_config}')
|
logger.info(f"{user_config}")
|
||||||
|
|
||||||
sanitizer = ConfigSanitizer(
|
sanitizer = ConfigSanitizer(
|
||||||
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
||||||
@@ -686,10 +698,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("[sanitized_user_config]")
|
logger.info("[sanitized_user_config]")
|
||||||
logger.info(f'{sanitized_user_config}')
|
logger.info(f"{sanitized_user_config}")
|
||||||
|
|
||||||
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info("[blueprint]")
|
logger.info("[blueprint]")
|
||||||
logger.info(f'{blueprint}')
|
logger.info(f"{blueprint}")
|
||||||
|
|||||||
@@ -1810,6 +1810,9 @@ class ControlNetDataset(BaseDataset):
|
|||||||
|
|
||||||
db_subsets = []
|
db_subsets = []
|
||||||
for subset in subsets:
|
for subset in subsets:
|
||||||
|
assert (
|
||||||
|
not subset.random_crop
|
||||||
|
), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
|
||||||
db_subset = DreamBoothSubset(
|
db_subset = DreamBoothSubset(
|
||||||
subset.image_dir,
|
subset.image_dir,
|
||||||
False,
|
False,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
@@ -157,7 +158,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
||||||
if use_user_config:
|
if use_user_config:
|
||||||
logger.info(f"Loading dataset config from {args.dataset_config}")
|
logger.info(f"Loading dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
@@ -834,6 +835,16 @@ class NetworkTrainer:
|
|||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
|
||||||
|
if args.masked_loss:
|
||||||
|
# mask image is -1 to 1. we need to convert it to 0 to 1
|
||||||
|
mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel
|
||||||
|
|
||||||
|
# resize to the same size as the loss
|
||||||
|
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
||||||
|
mask_image = mask_image / 2 + 0.5
|
||||||
|
loss = loss * mask_image
|
||||||
|
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
@@ -1050,6 +1061,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--masked_loss",
|
||||||
|
action="store_true",
|
||||||
|
help="apply mask for caclulating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user