Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 06:28:48 +00:00)
画像のアルファチャンネルをlossのマスクとして使用するオプションを追加 (#1223)
* Add alpha_mask parameter and apply masked loss
* Fix type hint in trim_and_resize_if_required function
* Refactor code to use keyword arguments in train_util.py
* Fix alpha mask flipping logic
* Fix alpha mask initialization
* Fix alpha_mask transformation
* Cache alpha_mask
* Update alpha_masks to be on CPU
* Set flipped_alpha_masks to Null if option disabled
* Check if alpha_mask is None
* Set alpha_mask to None if option disabled
* Add description of alpha_mask option to docs
This commit is contained in:
@@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
* Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
|
||||
* `--network_args`
|
||||
* 複数の引数を指定できます。後述します。
|
||||
* `--alpha_mask`
|
||||
* 画像のアルファ値をlossのマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
|
||||
|
||||
`--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
|
||||
|
||||
|
||||
@@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中
|
||||
* 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。
|
||||
* `--network_args`
|
||||
* 可以指定多个参数。将在下面详细说明。
|
||||
* `--alpha_mask`
|
||||
* 将图像的 Alpha 值用作损失(loss)的遮罩。在使用透明图像进行训练时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
|
||||
|
||||
当未指定`--network_train_unet_only`和`--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。
|
||||
|
||||
|
||||
@@ -78,6 +78,7 @@ class BaseSubsetParams:
|
||||
caption_tag_dropout_rate: float = 0.0
|
||||
token_warmup_min: int = 1
|
||||
token_warmup_step: float = 0
|
||||
alpha_mask: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -538,6 +539,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask},
|
||||
"""
|
||||
),
|
||||
" ",
|
||||
|
||||
@@ -479,9 +479,10 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
return noise
|
||||
|
||||
|
||||
def apply_masked_loss(loss, batch):
|
||||
def apply_masked_loss(loss, mask_image):
|
||||
# mask image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
# mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
mask_image = mask_image.to(dtype=loss.dtype)
|
||||
|
||||
# resize to the same size as the loss
|
||||
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
||||
|
||||
@@ -159,6 +159,9 @@ class ImageInfo:
|
||||
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
self.alpha_mask: Optional[torch.Tensor] = None
|
||||
self.alpha_mask_flipped: Optional[torch.Tensor] = None
|
||||
self.use_alpha_mask: bool = False
|
||||
|
||||
|
||||
class BucketManager:
|
||||
@@ -379,6 +382,7 @@ class BaseSubset:
|
||||
caption_suffix: Optional[str],
|
||||
token_warmup_min: int,
|
||||
token_warmup_step: Union[float, int],
|
||||
alpha_mask: bool,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.num_repeats = num_repeats
|
||||
@@ -403,6 +407,7 @@ class BaseSubset:
|
||||
|
||||
self.img_count = 0
|
||||
|
||||
self.alpha_mask = alpha_mask
|
||||
|
||||
class DreamBoothSubset(BaseSubset):
|
||||
def __init__(
|
||||
@@ -412,47 +417,13 @@ class DreamBoothSubset(BaseSubset):
|
||||
class_tokens: Optional[str],
|
||||
caption_extension: str,
|
||||
cache_info: bool,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator: str,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
caption_prefix,
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
caption_prefix,
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
@@ -473,47 +444,13 @@ class FineTuningSubset(BaseSubset):
|
||||
self,
|
||||
image_dir,
|
||||
metadata_file: str,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
caption_prefix,
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
caption_prefix,
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.metadata_file = metadata_file
|
||||
@@ -531,47 +468,13 @@ class ControlNetSubset(BaseSubset):
|
||||
conditioning_data_dir: str,
|
||||
caption_extension: str,
|
||||
cache_info: bool,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
caption_prefix,
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
caption_prefix,
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.conditioning_data_dir = conditioning_data_dir
|
||||
@@ -985,6 +888,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for info in tqdm(image_infos):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
info.use_alpha_mask = subset.alpha_mask
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
@@ -1088,8 +993,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def get_image_size(self, image_path):
|
||||
return imagesize.get(image_path)
|
||||
|
||||
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
|
||||
img = load_image(image_path)
|
||||
def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False):
|
||||
img = load_image(image_path, alpha_mask)
|
||||
|
||||
face_cx = face_cy = face_w = face_h = 0
|
||||
if subset.face_crop_aug_range is not None:
|
||||
@@ -1166,6 +1071,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
input_ids_list = []
|
||||
input_ids2_list = []
|
||||
latents_list = []
|
||||
alpha_mask_list = []
|
||||
images = []
|
||||
original_sizes_hw = []
|
||||
crop_top_lefts = []
|
||||
@@ -1190,21 +1096,27 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
|
||||
if not flipped:
|
||||
latents = image_info.latents
|
||||
alpha_mask = image_info.alpha_mask
|
||||
else:
|
||||
latents = image_info.latents_flipped
|
||||
|
||||
alpha_mask = image_info.alpha_mask_flipped
|
||||
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz)
|
||||
if flipped:
|
||||
latents = flipped_latents
|
||||
alpha_mask = flipped_alpha_mask
|
||||
del flipped_latents
|
||||
del flipped_alpha_mask
|
||||
latents = torch.FloatTensor(latents)
|
||||
if alpha_mask is not None:
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
|
||||
image = None
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
|
||||
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
@@ -1241,11 +1153,22 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if flipped:
|
||||
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem
|
||||
|
||||
if subset.alpha_mask:
|
||||
if img.shape[2] == 4:
|
||||
alpha_mask = img[:, :, 3] # [W,H]
|
||||
else:
|
||||
alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
|
||||
alpha_mask = transforms.ToTensor()(alpha_mask)
|
||||
else:
|
||||
alpha_mask = None
|
||||
img = img[:, :, :3] # remove alpha channel
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
|
||||
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
|
||||
@@ -1348,6 +1271,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||
|
||||
example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
return example
|
||||
@@ -2145,7 +2070,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
|
||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||
def load_latents_from_disk(
|
||||
npz_path,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz:
|
||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
@@ -2154,13 +2079,19 @@ def load_latents_from_disk(
|
||||
original_size = npz["original_size"].tolist()
|
||||
crop_ltrb = npz["crop_ltrb"].tolist()
|
||||
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents
|
||||
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
||||
flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
|
||||
|
||||
|
||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
|
||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None):
|
||||
kwargs = {}
|
||||
if flipped_latents_tensor is not None:
|
||||
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||
if alpha_mask is not None:
|
||||
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||
if flipped_alpha_mask is not None:
|
||||
kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy()
|
||||
np.savez(
|
||||
npz_path,
|
||||
latents=latents_tensor.float().cpu().numpy(),
|
||||
@@ -2349,17 +2280,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
|
||||
return train_dataset_group
|
||||
|
||||
|
||||
def load_image(image_path):
|
||||
def load_image(image_path, alpha=False):
|
||||
image = Image.open(image_path)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
if alpha:
|
||||
image = image.convert("RGBA")
|
||||
else:
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
|
||||
def trim_and_resize_if_required(
|
||||
random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
|
||||
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
|
||||
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
|
||||
image_height, image_width = image.shape[0:2]
|
||||
original_size = (image_width, image_height) # size before resize
|
||||
@@ -2403,10 +2337,18 @@ def cache_batch_latents(
|
||||
latents_original_size and latents_crop_ltrb are also set
|
||||
"""
|
||||
images = []
|
||||
alpha_masks = []
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
|
||||
image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
if info.use_alpha_mask:
|
||||
if image.shape[2] == 4:
|
||||
alpha_mask = image[:, :, 3] # [W,H]
|
||||
image = image[:, :, :3]
|
||||
else:
|
||||
alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
|
||||
alpha_masks.append(transforms.ToTensor()(alpha_mask))
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
images.append(image)
|
||||
|
||||
@@ -2419,25 +2361,37 @@ def cache_batch_latents(
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
|
||||
if info.use_alpha_mask:
|
||||
alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu")
|
||||
else:
|
||||
alpha_masks = [None] * len(image_infos)
|
||||
flipped_alpha_masks = [None] * len(image_infos)
|
||||
|
||||
if flip_aug:
|
||||
img_tensors = torch.flip(img_tensors, dims=[3])
|
||||
with torch.no_grad():
|
||||
flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
if info.use_alpha_mask:
|
||||
flipped_alpha_masks = torch.flip(alpha_masks, dims=[3])
|
||||
else:
|
||||
flipped_latents = [None] * len(latents)
|
||||
flipped_alpha_masks = [None] * len(image_infos)
|
||||
|
||||
for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
|
||||
for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks):
|
||||
# check NaN
|
||||
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
|
||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||
|
||||
if cache_to_disk:
|
||||
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
|
||||
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask)
|
||||
else:
|
||||
info.latents = latent
|
||||
if flip_aug:
|
||||
info.latents_flipped = flipped_latent
|
||||
|
||||
info.alpha_mask = alpha_mask
|
||||
info.alpha_mask_flipped = flipped_alpha_mask
|
||||
|
||||
if not HIGH_VRAM:
|
||||
clean_memory_on_device(vae.device)
|
||||
|
||||
@@ -3683,6 +3637,11 @@ def add_dataset_arguments(
|
||||
default=0,
|
||||
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha_mask",
|
||||
action="store_true",
|
||||
help="use alpha channel as mask for training / 画像のアルファチャンネルをlossのマスクに使用する",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset_class",
|
||||
|
||||
@@ -712,7 +712,9 @@ def train(args):
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
|
||||
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
|
||||
loss = apply_masked_loss(loss, batch["alpha_mask"])
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
|
||||
@@ -360,7 +360,9 @@ def train(args):
|
||||
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
|
||||
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
|
||||
loss = apply_masked_loss(loss, batch["alpha_mask"])
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
@@ -903,7 +903,9 @@ class NetworkTrainer:
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
|
||||
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
|
||||
loss = apply_masked_loss(loss, batch["alpha_mask"])
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
@@ -590,7 +590,9 @@ class TextualInversionTrainer:
|
||||
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
|
||||
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
|
||||
loss = apply_masked_loss(loss, batch["alpha_mask"])
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
@@ -475,7 +475,9 @@ def train(args):
|
||||
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
|
||||
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
|
||||
loss = apply_masked_loss(loss, batch["alpha_mask"])
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
Reference in New Issue
Block a user