support masked loss in sdxl_train ref #589

This commit is contained in:
Kohya S
2024-02-27 21:43:55 +09:00
parent 4a5546d40e
commit a9b64ffba8
2 changed files with 22 additions and 2 deletions

View File

@@ -251,7 +251,9 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
### Masked loss ### Masked loss
`train_network.py` and `sdxl_train_network.py` now support masked loss. The `--masked_loss` option has been added. `train_network.py`, `sdxl_train_network.py` and `sdxl_train.py` now support masked loss. The `--masked_loss` option has been added.
NOTE: `train_network.py` and `sdxl_train.py` are not tested yet.
A ControlNet dataset is used to specify the mask. The mask images should be RGB images. A pixel value of 255 in the R channel is treated as masked (the loss is calculated only for the masked pixels), and 0 is treated as unmasked. For details of the dataset specification, see the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset). A ControlNet dataset is used to specify the mask. The mask images should be RGB images. A pixel value of 255 in the R channel is treated as masked (the loss is calculated only for the masked pixels), and 0 is treated as unmasked. For details of the dataset specification, see the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).

View File

@@ -11,6 +11,7 @@ from tqdm import tqdm
import torch import torch
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
init_ipex() init_ipex()
from accelerate.utils import set_seed from accelerate.utils import set_seed
@@ -124,7 +125,7 @@ def train(args):
# データセットを準備する # データセットを準備する
if args.dataset_class is None: if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None: if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}") logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config) user_config = config_util.load_user_config(args.dataset_config)
@@ -579,6 +580,16 @@ def train(args):
): ):
# do not mean over batch dimension for snr weight or scale v-pred loss # do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
if args.masked_loss:
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=weight_dtype)[:, 0].unsqueeze(1) # use R channel
# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])
if args.min_snr_gamma: if args.min_snr_gamma:
@@ -780,6 +791,13 @@ def setup_parser() -> argparse.ArgumentParser:
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
) )
# TODO common masked_loss argument
parser.add_argument(
"--masked_loss",
action="store_true",
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
)
return parser return parser