mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support masked loss in sdxl_train ref #589
This commit is contained in:
@@ -11,6 +11,7 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
@@ -124,7 +125,7 @@ def train(args):
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
@@ -579,6 +580,16 @@ def train(args):
|
||||
):
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
|
||||
if args.masked_loss:
|
||||
# mask image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
|
||||
# resize to the same size as the loss
|
||||
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
||||
mask_image = mask_image / 2 + 0.5
|
||||
loss = loss * mask_image
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
@@ -780,6 +791,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
|
||||
)
|
||||
|
||||
# TODO common masked_loss argument
|
||||
parser.add_argument(
|
||||
"--masked_loss",
|
||||
action="store_true",
|
||||
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user