mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support masked loss in sdxl_train ref #589
This commit is contained in:
@@ -251,7 +251,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
|||||||
|
|
||||||
### Masked loss
|
### Masked loss
|
||||||
|
|
||||||
`train_network.py` and `sdxl_train_network.py` now support the masked loss. `--masked_loss` option is added.
|
`train_network.py`, `sdxl_train_network.py` and `sdxl_train.py` now support the masked loss. `--masked_loss` option is added.
|
||||||
|
|
||||||
|
NOTE: `train_network.py` and `sdxl_train.py` are not tested yet.
|
||||||
|
|
||||||
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
|
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
@@ -124,7 +125,7 @@ def train(args):
|
|||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
@@ -579,6 +580,16 @@ def train(args):
|
|||||||
):
|
):
|
||||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
|
||||||
|
if args.masked_loss:
|
||||||
|
# mask image is -1 to 1. we need to convert it to 0 to 1
|
||||||
|
mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel
|
||||||
|
|
||||||
|
# resize to the same size as the loss
|
||||||
|
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
||||||
|
mask_image = mask_image / 2 + 0.5
|
||||||
|
loss = loss * mask_image
|
||||||
|
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
@@ -780,6 +791,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
|
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO common masked_loss argument
|
||||||
|
parser.add_argument(
|
||||||
|
"--masked_loss",
|
||||||
|
action="store_true",
|
||||||
|
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user