mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Preference optimization with MaPO and Diffusion-DPO
This commit is contained in:
@@ -75,6 +75,11 @@ class BaseSubsetParams:
|
||||
custom_attributes: Optional[Dict[str, Any]] = None
|
||||
validation_seed: int = 0
|
||||
validation_split: float = 0.0
|
||||
preference: bool = False
|
||||
preference_caption_prefix: Optional[str] = None
|
||||
preference_caption_suffix: Optional[str] = None
|
||||
non_preference_caption_prefix: Optional[str] = None
|
||||
non_preference_caption_suffix: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -196,6 +201,11 @@ class ConfigSanitizer:
|
||||
"caption_prefix": str,
|
||||
"caption_suffix": str,
|
||||
"custom_attributes": dict,
|
||||
"preference": bool,
|
||||
"preference_caption_prefix": str,
|
||||
"preference_caption_suffix": str,
|
||||
"non_preference_caption_prefix": str,
|
||||
"non_preference_caption_suffix": str
|
||||
}
|
||||
# DO means DropOut
|
||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
|
||||
@@ -209,6 +209,20 @@ class ImageInfo:
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
|
||||
class ImageSetInfo(ImageInfo):
    """An ``ImageInfo`` record that tracks several image files under one key.

    The constructor registers the first file; further files are attached with
    :meth:`add`. ``absolute_paths`` and ``captions`` are seeded with the
    constructor arguments, while ``image_sizes`` starts out empty and is left
    for the caller to populate for the first file.
    """

    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
        super().__init__(image_key, num_repeats, caption, is_reg, absolute_path)

        # Per-file lists; index 0 refers to the file given to the constructor.
        self.absolute_paths, self.captions = [absolute_path], [caption]
        # Not seeded here: no size is passed to the constructor.
        self.image_sizes = []

    def add(self, absolute_path, caption, size):
        """Attach one more file (path, caption and size) to this set."""
        additions = (
            (self.absolute_paths, absolute_path),
            (self.captions, caption),
            (self.image_sizes, size),
        )
        for target_list, value in additions:
            target_list.append(value)
|
||||
|
||||
|
||||
class BucketManager:
|
||||
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
||||
if max_size is not None:
|
||||
@@ -431,6 +445,11 @@ class BaseSubset:
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
# Preference-pair options. They carry defaults because they are declared after
# parameters that already have defaults (a parameter without a default may not
# follow one with a default), and so that existing callers need not pass them.
preference: bool = False,
preference_caption_prefix: Optional[str] = None,
preference_caption_suffix: Optional[str] = None,
non_preference_caption_prefix: Optional[str] = None,
non_preference_caption_suffix: Optional[str] = None,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||
@@ -455,6 +474,11 @@ class BaseSubset:
|
||||
self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる
|
||||
|
||||
self.custom_attributes = custom_attributes if custom_attributes is not None else {}
|
||||
# Preference-pair settings. Each attribute is taken from the argument of the
# same name; the preferred-side prefix/suffix must not be filled from the
# non_preference_* arguments.
self.preference = preference
self.preference_caption_prefix = preference_caption_prefix
self.preference_caption_suffix = preference_caption_suffix
self.non_preference_caption_prefix = non_preference_caption_prefix
self.non_preference_caption_suffix = non_preference_caption_suffix
|
||||
|
||||
self.img_count = 0
|
||||
|
||||
@@ -492,6 +516,11 @@ class DreamBoothSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
# Preference-pair options. Defaults are required because these follow
# parameters that already have defaults, and they let callers that only pass
# `preference` (or none of these) keep working.
preference: bool = False,
preference_caption_prefix: Optional[str] = None,
preference_caption_suffix: Optional[str] = None,
non_preference_caption_prefix: Optional[str] = None,
non_preference_caption_suffix: Optional[str] = None,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -519,6 +548,11 @@ class DreamBoothSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
# Passed by keyword: positional arguments are not allowed after the keyword
# arguments above.
preference=preference,
preference_caption_prefix=preference_caption_prefix,
preference_caption_suffix=preference_caption_suffix,
non_preference_caption_prefix=non_preference_caption_prefix,
non_preference_caption_suffix=non_preference_caption_suffix,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
@@ -1462,6 +1496,75 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
image_size = (0, 0)
|
||||
return image_size
|
||||
|
||||
def load_and_transform_image(self, subset, image_info, absolute_path, flipped):
    """Load one image file and run it through crop, augmentation, flip and tensor transform.

    Args:
        subset: subset settings; ``alpha_mask``, ``random_crop`` and ``color_aug`` are read here.
        image_info: image record; ``bucket_reso`` and ``resized_size`` are read when bucketing is enabled.
        absolute_path: path of the file to load. It is passed separately from ``image_info``
            so that a record holding several files can load each of them.
        flipped: when true, the image is mirrored horizontally before the transform.

    Returns:
        A tuple ``(image, original_size, crop_left_top, alpha_mask)``:
        ``image`` is the result of ``self.image_transforms``; ``original_size`` and the
        crop box come from ``trim_and_resize_if_required`` when bucketing is enabled,
        otherwise ``original_size`` is ``[width, height]`` after cropping and the crop
        box is all zeros; ``alpha_mask`` is a float tensor or ``None``.

    Raises:
        AssertionError: when bucketing is disabled and the image is too large while
            ``subset.random_crop`` is off, or when the final size does not equal
            ``(self.height, self.width)``.

    NOTE(review): the third return value is a (left, top) pair, while the call sites
    visible in this file bind it to a name ``crop_ltrb`` -- confirm callers do not
    index it as a 4-tuple.
    """
    # load the image and crop it if required
    img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, absolute_path, subset.alpha_mask)
    im_h, im_w = img.shape[0:2]

    if self.enable_bucket:
        img, original_size, crop_ltrb = trim_and_resize_if_required(
            subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
        )
    else:
        if face_cx > 0:  # face position information is available
            img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
        elif im_h > self.height or im_w > self.width:
            # Report the file that was actually loaded (absolute_path), which may differ
            # from image_info.absolute_path when the record holds several files.
            assert (
                subset.random_crop
            ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {absolute_path}"
            if im_h > self.height:
                p = random.randint(0, im_h - self.height)
                img = img[p : p + self.height]
            if im_w > self.width:
                p = random.randint(0, im_w - self.width)
                img = img[:, p : p + self.width]

        im_h, im_w = img.shape[0:2]
        assert (
            im_h == self.height and im_w == self.width
        ), f"image size is small / 画像サイズが小さいようです: {absolute_path}"

        original_size = [im_w, im_h]
        crop_ltrb = (0, 0, 0, 0)

    # augmentation
    aug = self.aug_helper.get_augmentor(subset.color_aug)
    if aug is not None:
        # augment RGB channels only; a fourth (alpha) channel is left untouched
        img_rgb = img[:, :, :3]
        img_rgb = aug(image=img_rgb)["image"]
        img[:, :, :3] = img_rgb

    if flipped:
        img = img[:, ::-1, :].copy()  # copy to avoid negative stride problem

    if subset.alpha_mask:
        if img.shape[2] == 4:
            alpha_mask = img[:, :, 3]  # [H,W]
            alpha_mask = alpha_mask.astype(np.float32) / 255.0  # 0.0 to 1.0
            alpha_mask = torch.FloatTensor(alpha_mask)
        else:
            # no alpha channel in the file: use a mask of ones
            alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
    else:
        alpha_mask = None

    img = img[:, :, :3]  # remove alpha channel

    image = self.image_transforms(img)  # becomes a torch.Tensor in the range -1.0 to 1.0

    # The image is always loaded in this method, so its size comes from the tensor itself
    # (the former fallback referenced `latents`, which does not exist here).
    target_size = (image.shape[2], image.shape[1])

    if not flipped:
        crop_left_top = (crop_ltrb[0], crop_ltrb[1])
    else:
        # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
        crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])

    return image, original_size, crop_left_top, alpha_mask
|
||||
|
||||
def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False):
|
||||
img = load_image(image_path, alpha_mask)
|
||||
|
||||
@@ -1571,6 +1674,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
|
||||
|
||||
image = None
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
||||
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
|
||||
@@ -1584,69 +1690,22 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
|
||||
image = None
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
|
||||
subset, image_info.absolute_path, subset.alpha_mask
|
||||
)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
|
||||
)
|
||||
if isinstance(image_info, ImageSetInfo):
|
||||
for absolute_path in image_info.absolute_paths:
|
||||
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, absolute_path, flipped)
|
||||
images.append(image)
|
||||
latents_list.append(None)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
else:
|
||||
if face_cx > 0: # 顔位置情報あり
|
||||
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
||||
elif im_h > self.height or im_w > self.width:
|
||||
assert (
|
||||
subset.random_crop
|
||||
), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
||||
if im_h > self.height:
|
||||
p = random.randint(0, im_h - self.height)
|
||||
img = img[p : p + self.height]
|
||||
if im_w > self.width:
|
||||
p = random.randint(0, im_w - self.width)
|
||||
img = img[:, p : p + self.width]
|
||||
|
||||
im_h, im_w = img.shape[0:2]
|
||||
assert (
|
||||
im_h == self.height and im_w == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
original_size = [im_w, im_h]
|
||||
crop_ltrb = (0, 0, 0, 0)
|
||||
|
||||
# augmentation
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug)
|
||||
if aug is not None:
|
||||
# augment RGB channels only
|
||||
img_rgb = img[:, :, :3]
|
||||
img_rgb = aug(image=img_rgb)["image"]
|
||||
img[:, :, :3] = img_rgb
|
||||
|
||||
if flipped:
|
||||
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem
|
||||
|
||||
if subset.alpha_mask:
|
||||
if img.shape[2] == 4:
|
||||
alpha_mask = img[:, :, 3] # [H,W]
|
||||
alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
else:
|
||||
alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
|
||||
else:
|
||||
alpha_mask = None
|
||||
|
||||
img = img[:, :, :3] # remove alpha channel
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
del img
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped)
|
||||
images.append(image)
|
||||
latents_list.append(None)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
|
||||
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
|
||||
@@ -1933,6 +1992,11 @@ class DreamBoothDataset(BaseDataset):
|
||||
img_paths = list(metas.keys())
|
||||
sizes = [meta["resolution"] for meta in metas.values()]
|
||||
|
||||
elif subset.preference:
|
||||
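# We assume image_dir contains a "w" subdirectory with the preferred (winner) images; the non-preferred (loser) images are looked up under the same file names in a sibling "l" directory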
|
||||
winner_path = str(pathlib.Path(subset.image_dir) / "w")
|
||||
img_paths = glob_images(winner_path, "*")
|
||||
sizes = [None] * len(img_paths)
|
||||
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
|
||||
else:
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
@@ -2081,9 +2145,41 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_train_images += num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.preference:
|
||||
def get_non_preferred_pair_info(img_path, subset):
    """Build the (path, caption, size) triple for the non-preferred partner of ``img_path``.

    The partner path is ``img_path`` with every ``w`` in its parent directory
    name replaced by ``l``; the file name is kept.

    Returns:
        ``(loser_img_path, caption, image_size)``. ``image_size`` is ``None`` when the
        enclosing loop's ``size`` is ``None``; otherwise it is measured on the
        non-preferred file.
    """
    winner_dir, file_name = os.path.split(img_path)
    parent_dir, winner_dir_name = os.path.split(winner_dir)
    loser_dir_name = winner_dir_name.replace('w', 'l')
    loser_img_path = os.path.join(parent_dir, loser_dir_name, file_name)

    # NOTE(review): the caption is read via the preferred image's path, i.e. both
    # images of a pair start from the same caption text -- confirm this is intended.
    caption = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)

    if subset.non_preference_caption_prefix:
        caption = subset.non_preference_caption_prefix + " " + caption
    if subset.non_preference_caption_suffix:
        caption = caption + " " + subset.non_preference_caption_suffix

    # Measure the non-preferred file itself: this value is stored as that file's size,
    # and the preferred file's size is already known to the caller as `size`.
    image_size = self.get_image_size(loser_img_path) if size is not None else None

    return loser_img_path, caption, image_size
|
||||
|
||||
if subset.preference_caption_prefix:
|
||||
caption = subset.preference_caption_prefix + " " + caption
|
||||
if subset.preference_caption_suffix:
|
||||
caption = caption + " " + subset.preference_caption_suffix
|
||||
|
||||
info = ImageSetInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
info.image_sizes = [size]
|
||||
else:
|
||||
info.image_sizes = [None]
|
||||
info.add(*get_non_preferred_pair_info(img_path, subset))
|
||||
else:
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
|
||||
if subset.is_reg:
|
||||
reg_infos.append((info, subset))
|
||||
else:
|
||||
@@ -2398,6 +2494,7 @@ class ControlNetDataset(BaseDataset):
|
||||
subset.caption_suffix,
|
||||
subset.token_warmup_min,
|
||||
subset.token_warmup_step,
|
||||
# Passed by keyword: `preference` is declared after other defaulted parameters
# in the subset constructor, so a positional value here would bind to a
# different parameter.
preference=subset.preference,
|
||||
)
|
||||
db_subsets.append(db_subset)
|
||||
|
||||
@@ -4113,6 +4210,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beta_dpo",
|
||||
type=int,
|
||||
help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--orpo_weight",
|
||||
type=float,
|
||||
help="ORPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 ~ 0.25 です",
|
||||
)
|
||||
# MaPO weight; the Japanese half of the help text names MaPO (it previously said ORPO).
parser.add_argument(
    "--mapo_weight",
    type=float,
    help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の MaPO 重み。推奨値は 0.1 ~ 0.25 です",
)
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
@@ -4422,6 +4534,30 @@ def add_dataset_arguments(
|
||||
default=None,
|
||||
help="suffix for caption text / captionのテキストの末尾に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preference_caption_prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="prefix for preference caption text / captionのテキストの先頭に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preference_caption_suffix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="suffix for preference caption text / captionのテキストの末尾に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non_preference_caption_prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="prefix for non-preference caption text / captionのテキストの先頭に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non_preference_caption_suffix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="suffix for non-preference caption text / captionのテキストの末尾に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする"
|
||||
)
|
||||
@@ -5956,6 +6092,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.
|
||||
|
||||
# Sample a random timestep for each image
|
||||
b_size = latents.shape[0]
|
||||
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
|
||||
|
||||
@@ -406,7 +406,89 @@ class NetworkTrainer:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.beta_dpo is not None:
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
model_loss_w, model_loss_l = model_loss.chunk(2)
|
||||
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
|
||||
model_diff = model_loss_w - model_loss_l
|
||||
|
||||
# ref loss
|
||||
with torch.no_grad():
|
||||
# disable network for reference
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
|
||||
with accelerator.autocast():
|
||||
ref_noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
)
|
||||
ref_loss = train_util.conditional_loss(
|
||||
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
ref_loss = apply_masked_loss(ref_loss, batch)
|
||||
|
||||
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
raw_ref_loss = ref_loss.mean()
|
||||
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||
|
||||
scale_term = -0.5 * args.beta_dpo
|
||||
inside_term = scale_term * (model_diff - ref_diff)
|
||||
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
|
||||
|
||||
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
|
||||
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
|
||||
|
||||
accelerator.log({
|
||||
"total_loss": model_loss.detach().mean().item(),
|
||||
"raw_model_loss": raw_model_loss.detach().mean().item(),
|
||||
"ref_loss": raw_ref_loss.detach().item(),
|
||||
"implicit_acc": implicit_acc.detach().item(),
|
||||
}, step=global_step)
|
||||
elif args.mapo_weight is not None:
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
|
||||
snr = 0.5
|
||||
model_losses_w, model_losses_l = model_loss.chunk(2)
|
||||
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
|
||||
snr * model_losses_l
|
||||
) / (torch.exp(snr * model_losses_l) - 1)
|
||||
|
||||
# Ratio loss.
|
||||
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
|
||||
ratio = torch.nn.functional.logsigmoid(log_odds * noise_scheduler.config.num_train_timesteps)
|
||||
ratio_losses = args.mapo_weight * ratio
|
||||
|
||||
# Full MaPO loss
|
||||
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
|
||||
|
||||
accelerator.log({
|
||||
"total_loss": loss.detach().mean().item(),
|
||||
"ratio_loss": -ratio_losses.mean().detach().item(),
|
||||
"model_losses_w": model_losses_w.mean().detach().item(),
|
||||
"model_losses_l": model_losses_l.mean().detach().item(),
|
||||
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1))
|
||||
.mean()
|
||||
.detach()
|
||||
.item(),
|
||||
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1))
|
||||
.mean()
|
||||
.detach()
|
||||
.item()
|
||||
}, step=global_step)
|
||||
else:
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
Reference in New Issue
Block a user