Preference optimization with MaPO and Diffusion-DPO

This commit is contained in:
rockerBOO
2024-06-19 16:22:11 -04:00
parent e5bab69e3a
commit 44fa71c78f
3 changed files with 365 additions and 102 deletions

View File

@@ -78,6 +78,11 @@ class BaseSubsetParams:
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
preference: bool = False
preference_caption_prefix: Optional[str] = None
preference_caption_suffix: Optional[str] = None
non_preference_caption_prefix: Optional[str] = None
non_preference_caption_suffix: Optional[str] = None
@dataclass
@@ -199,6 +204,11 @@ class ConfigSanitizer:
"token_warmup_step": Any(float, int),
"caption_prefix": str,
"caption_suffix": str,
"preference": bool,
"preference_caption_prefix": str,
"preference_caption_suffix": str,
"non_preference_caption_prefix": str,
"non_preference_caption_suffix": str
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -540,14 +550,29 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask},
token_warmup_min: {subset.token_warmup_min}
token_warmup_step: {subset.token_warmup_step}
alpha_mask: {subset.alpha_mask}
preference: {subset.preference}
"""
),
" ",
)
if subset.preference:
info += indent(
dedent(
f"""\
preference_caption_prefix: {subset.preference_caption_prefix}
preference_caption_suffix: {subset.preference_caption_suffix}
non_preference_caption_prefix: {subset.non_preference_caption_prefix}
non_preference_caption_suffix: {subset.non_preference_caption_suffix}
\n"""
),
" ",
)
if is_dreambooth:
info += indent(
dedent(

View File

@@ -162,6 +162,20 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
class ImageSetInfo(ImageInfo):
    """ImageInfo holding a *set* of related images (e.g. a preference
    winner/loser pair) that share one dataset entry.

    The parallel lists ``absolute_paths`` / ``captions`` / ``image_sizes``
    hold one element per image in the set.
    """

    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, size=None) -> None:
        super().__init__(image_key, num_repeats, caption, is_reg, absolute_path)
        self.absolute_paths = [absolute_path]
        self.captions = [caption]
        # Keep the three lists aligned. The original code left image_sizes
        # empty and relied on the caller to backfill it, which made the lists
        # one element shorter than the others until then; accepting an
        # optional `size` lets the first entry be recorded up front while
        # staying backward compatible (default keeps the old behavior).
        self.image_sizes = [] if size is None else [size]

    def add(self, absolute_path, caption, size):
        """Append one more image (path, caption, size) to the set."""
        self.absolute_paths.append(absolute_path)
        self.captions.append(caption)
        self.image_sizes.append(size)
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
if max_size is not None:
@@ -381,6 +395,11 @@ class BaseSubset:
caption_suffix: Optional[str],
token_warmup_min: int,
token_warmup_step: Union[float, int],
preference: bool,
preference_caption_prefix: Optional[str],
preference_caption_suffix: Optional[str],
non_preference_caption_prefix: Optional[str],
non_preference_caption_suffix: Optional[str],
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -404,6 +423,12 @@ class BaseSubset:
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step # NN<1ならN*max_train_stepsステップ目でタグの数が最大になる
self.preference = preference
# Fixed copy/paste bug: the preference prefix/suffix were previously
# assigned from the NON-preference arguments, so both caption pairs ended
# up identical and the preferred-image caption decoration was lost.
self.preference_caption_prefix = preference_caption_prefix
self.preference_caption_suffix = preference_caption_suffix
self.non_preference_caption_prefix = non_preference_caption_prefix
self.non_preference_caption_suffix = non_preference_caption_suffix
self.img_count = 0
@@ -434,6 +459,11 @@ class DreamBoothSubset(BaseSubset):
caption_suffix,
token_warmup_min,
token_warmup_step,
preference: bool,
preference_caption_prefix,
preference_caption_suffix,
non_preference_caption_prefix,
non_preference_caption_suffix,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -458,6 +488,11 @@ class DreamBoothSubset(BaseSubset):
caption_suffix,
token_warmup_min,
token_warmup_step,
preference,
preference_caption_prefix,
preference_caption_suffix,
non_preference_caption_prefix,
non_preference_caption_suffix,
)
self.is_reg = is_reg
@@ -1098,6 +1133,75 @@ class BaseDataset(torch.utils.data.Dataset):
def get_image_size(self, image_path):
return imagesize.get(image_path)
def load_and_transform_image(self, subset, image_info, absolute_path, flipped):
    """Load one image, crop/resize it for its bucket, apply augmentation and
    flipping, and convert it to a normalized tensor.

    Args:
        subset: the subset configuration (alpha_mask, random_crop, color_aug, ...).
        image_info: metadata entry providing bucket_reso / resized_size.
        absolute_path: path of the image file to load.
        flipped: whether to horizontally flip the image.

    Returns:
        (image, original_size, crop_left_top, alpha_mask) where `image` is a
        -1.0..1.0 torch.Tensor, `original_size` is [w, h] before resizing,
        `crop_left_top` is the top-left crop offset (mirrored when flipped),
        and `alpha_mask` is an HxW float tensor or None.
    """
    # Load the image (and face position info, if available)
    img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
        subset, absolute_path, subset.alpha_mask
    )
    im_h, im_w = img.shape[0:2]

    if self.enable_bucket:
        img, original_size, crop_ltrb = trim_and_resize_if_required(
            subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
        )
    else:
        if face_cx > 0:  # face position info is available
            img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
        elif im_h > self.height or im_w > self.width:
            assert (
                subset.random_crop
            ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
            if im_h > self.height:
                p = random.randint(0, im_h - self.height)
                img = img[p : p + self.height]
            if im_w > self.width:
                p = random.randint(0, im_w - self.width)
                img = img[:, p : p + self.width]

        im_h, im_w = img.shape[0:2]
        assert (
            im_h == self.height and im_w == self.width
        ), f"image size is small / 画像サイズが小さいようです: {absolute_path}"

        original_size = [im_w, im_h]
        crop_ltrb = (0, 0, 0, 0)

    # augmentation (RGB channels only; the alpha channel is left untouched)
    aug = self.aug_helper.get_augmentor(subset.color_aug)
    if aug is not None:
        img_rgb = img[:, :, :3]
        img_rgb = aug(image=img_rgb)["image"]
        img[:, :, :3] = img_rgb

    if flipped:
        img = img[:, ::-1, :].copy()  # copy to avoid negative stride problem

    if subset.alpha_mask:
        if img.shape[2] == 4:
            alpha_mask = img[:, :, 3]  # [H,W]
            alpha_mask = alpha_mask.astype(np.float32) / 255.0  # 0.0~1.0
            alpha_mask = torch.FloatTensor(alpha_mask)
        else:
            # image has no alpha channel: treat it as fully opaque
            alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
    else:
        alpha_mask = None

    img = img[:, :, :3]  # remove alpha channel

    image = self.image_transforms(img)  # becomes a -1.0..1.0 torch.Tensor

    # Fixed: the previous fallback referenced an undefined `latents` variable
    # (`... if image is not None else (latents.shape[2] * 8, ...)`); `image`
    # is always set in this method, so derive target_size from it directly.
    target_size = (image.shape[2], image.shape[1])

    if not flipped:
        crop_left_top = (crop_ltrb[0], crop_ltrb[1])
    else:
        # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
        crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])

    return image, original_size, crop_left_top, alpha_mask
def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False):
img = load_image(image_path, alpha_mask)
@@ -1171,20 +1275,26 @@ class BaseDataset(torch.utils.data.Dataset):
if self.caching_mode is not None: # return batch for latents/text encoder outputs caching
return self.get_item_for_caching(bucket, bucket_batch_size, image_index)
loss_weights = []
captions = []
input_ids_list = []
input_ids2_list = []
latents_list = []
alpha_mask_list = []
images = []
original_sizes_hw = []
crop_top_lefts = []
target_sizes_hw = []
flippeds = [] # 変数名が微妙
text_encoder_outputs1_list = []
text_encoder_outputs2_list = []
text_encoder_pool2_list = []
loss_weights: List[float] = []
images: List[Image] = []
latents_list: List[torch.Tensor] = []
alpha_mask_list: List[torch.Tensor] = []
original_sizes_hw: List[Tuple[int, int]] = []
crop_top_lefts: List[Tuple[int, int]] = []
target_sizes_hw: List[Tuple[int, int]] = []
flippeds: List[bool] = [] # 変数名が微妙
captions: List[str] = []
input_ids_list: List[List[int]] = []
input_ids2_list: List[List[int]] = []
text_encoder_outputs1_list: List[Optional[torch.Tensor]] = []
text_encoder_outputs2_list: List[Optional[torch.Tensor]] = []
text_encoder_pool2_list: List[Optional[torch.Tensor]] = []
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
@@ -1207,6 +1317,9 @@ class BaseDataset(torch.utils.data.Dataset):
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
image = None
images.append(image)
latents_list.append(latents)
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz)
if flipped:
@@ -1218,69 +1331,22 @@ class BaseDataset(torch.utils.data.Dataset):
alpha_mask = torch.FloatTensor(alpha_mask)
image = None
images.append(image)
latents_list.append(latents)
alpha_mask_list.append(alpha_mask)
else:
# 画像を読み込み、必要ならcropする
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
subset, image_info.absolute_path, subset.alpha_mask
)
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
)
if isinstance(image_info, ImageSetInfo):
for absolute_path in image_info.absolute_paths:
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, absolute_path, flipped)
images.append(image)
latents_list.append(None)
alpha_mask_list.append(alpha_mask)
else:
if face_cx > 0: # 顔位置情報あり
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
elif im_h > self.height or im_w > self.width:
assert (
subset.random_crop
), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
if im_h > self.height:
p = random.randint(0, im_h - self.height)
img = img[p : p + self.height]
if im_w > self.width:
p = random.randint(0, im_w - self.width)
img = img[:, p : p + self.width]
im_h, im_w = img.shape[0:2]
assert (
im_h == self.height and im_w == self.width
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
original_size = [im_w, im_h]
crop_ltrb = (0, 0, 0, 0)
# augmentation
aug = self.aug_helper.get_augmentor(subset.color_aug)
if aug is not None:
# augment RGB channels only
img_rgb = img[:, :, :3]
img_rgb = aug(image=img_rgb)["image"]
img[:, :, :3] = img_rgb
if flipped:
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem
if subset.alpha_mask:
if img.shape[2] == 4:
alpha_mask = img[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0
alpha_mask = torch.FloatTensor(alpha_mask)
else:
alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
else:
alpha_mask = None
img = img[:, :, :3] # remove alpha channel
latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
del img
images.append(image)
latents_list.append(latents)
alpha_mask_list.append(alpha_mask)
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped)
images.append(image)
latents_list.append(None)
alpha_mask_list.append(alpha_mask)
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
@@ -1311,31 +1377,41 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_pool2_list.append(text_encoder_pool2)
captions.append(caption)
else:
caption = self.process_caption(subset, image_info.caption)
if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
captions_list = []
if isinstance(image_info, ImageSetInfo):
for image_info_caption in image_info.captions:
caption = self.process_caption(subset, image_info_caption)
captions_list.append(caption)
else:
captions.append(caption)
caption = self.process_caption(subset, image_info.caption)
captions_list.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
for caption in captions_list:
if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
token_caption = self.get_input_ids(caption, self.tokenizers[0])
input_ids_list.append(token_caption)
captions.append(caption)
if len(self.tokenizers) > 1:
if not self.token_padding_disabled: # this option might be omitted in future
if self.XTI_layers:
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
else:
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
input_ids2_list.append(token_caption2)
token_caption = self.get_input_ids(caption, self.tokenizers[0])
input_ids_list.append(token_caption)
if len(self.tokenizers) > 1:
if self.XTI_layers:
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
else:
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
input_ids2_list.append(token_caption2)
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
@@ -1568,6 +1644,11 @@ class DreamBoothDataset(BaseDataset):
img_paths = list(metas.keys())
sizes = [meta["resolution"] for meta in metas.values()]
elif subset.preference:
# We assume a image_dir path pattern for winner/loser
winner_path = str(pathlib.Path(subset.image_dir) / "w")
img_paths = glob_images(winner_path, "*")
sizes = [None] * len(img_paths)
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
else:
img_paths = glob_images(subset.image_dir, "*")
@@ -1656,9 +1737,41 @@ class DreamBoothDataset(BaseDataset):
num_train_images += subset.num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.preference:
def get_non_preferred_pair_info(img_path, subset):
    """Derive the losing (non-preferred) image's path, caption and size for a
    preference winner image.

    Winner images live under ".../w/"; the paired loser is expected at the
    same filename under ".../l/".
    """
    head, file = os.path.split(img_path)
    head, tail = os.path.split(head)
    # NOTE(review): tail is expected to be exactly "w" (see the glob over
    # image_dir/"w" above); replace() would also rewrite any other 'w'
    # character if the directory name ever differs — confirm the layout.
    new_tail = tail.replace('w', 'l')
    loser_img_path = os.path.join(head, new_tail, file)

    # the loser reuses the winner's caption file, decorated with the
    # non-preference prefix/suffix
    caption = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
    if subset.non_preference_caption_prefix:
        caption = subset.non_preference_caption_prefix + " " + caption
    if subset.non_preference_caption_suffix:
        caption = caption + " " + subset.non_preference_caption_suffix

    # Fixed: read the size of the LOSER image, not the winner's again.
    # NOTE(review): `size` (closure) is always None in the preference branch
    # (sizes = [None] * len(img_paths)), so this currently always yields
    # None — the `size is not None` condition looks inverted; TODO confirm.
    image_size = self.get_image_size(loser_img_path) if size is not None else None
    return loser_img_path, caption, image_size
if subset.preference_caption_prefix:
caption = subset.preference_caption_prefix + " " + caption
if subset.preference_caption_suffix:
caption = caption + " " + subset.preference_caption_suffix
info = ImageSetInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
info.image_sizes = [size]
else:
info.image_sizes = [None]
info.add(*get_non_preferred_pair_info(img_path, subset))
else:
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg:
reg_infos.append((info, subset))
else:
@@ -1968,6 +2081,7 @@ class ControlNetDataset(BaseDataset):
subset.caption_suffix,
subset.token_warmup_min,
subset.token_warmup_step,
subset.preference
)
db_subsets.append(db_subset)
@@ -3471,6 +3585,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
parser.add_argument(
    "--beta_dpo",
    type=int,
    default=None,  # None disables the Diffusion-DPO branch in the trainer
    help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
)
parser.add_argument(
    "--orpo_weight",
    type=float,
    default=None,  # None disables ORPO
    help="ORPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 0.25 です",
)
parser.add_argument(
    "--mapo_weight",
    type=float,
    default=None,  # None disables MaPO
    # Fixed copy/paste: the Japanese help text previously said "ORPO" here
    help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の MaPO 重み。推奨値は 0.1 0.25 です",
)
if support_dreambooth:
# DreamBooth training
@@ -3714,6 +3843,30 @@ def add_dataset_arguments(
default=None,
help="suffix for caption text / captionのテキストの末尾に付ける文字列",
)
# Caption decoration for preference (winner) / non-preference (loser) pairs.
# Fixed: the Japanese help texts were copy-pasted from the plain
# caption_prefix/suffix options and did not mention preference at all.
parser.add_argument(
    "--preference_caption_prefix",
    type=str,
    default=None,
    help="prefix for preference caption text / preference(勝者)画像のcaptionのテキストの先頭に付ける文字列",
)
parser.add_argument(
    "--preference_caption_suffix",
    type=str,
    default=None,
    help="suffix for preference caption text / preference(勝者)画像のcaptionのテキストの末尾に付ける文字列",
)
parser.add_argument(
    "--non_preference_caption_prefix",
    type=str,
    default=None,
    help="prefix for non-preference caption text / non-preference(敗者)画像のcaptionのテキストの先頭に付ける文字列",
)
parser.add_argument(
    "--non_preference_caption_suffix",
    type=str,
    default=None,
    help="suffix for non-preference caption text / non-preference(敗者)画像のcaptionのテキストの末尾に付ける文字列",
)
parser.add_argument(
"--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする"
)
@@ -5075,6 +5228,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample a random timestep for each image
b_size = latents.shape[0]
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
@@ -5488,7 +5642,7 @@ def sample_image_inference(
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
wandb_tracker.log({f"sample_{i}": wandb.Image(image)}, step=steps)
except: # wandb 無効時
pass

View File

@@ -906,10 +906,93 @@ class NetworkTrainer:
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
if args.beta_dpo is not None:
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
model_loss_w, model_loss_l = model_loss.chunk(2)
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
model_diff = model_loss_w - model_loss_l
# ref loss
with torch.no_grad():
# disable network for reference
accelerator.unwrap_model(network).set_multiplier(0.0)
with accelerator.autocast():
ref_noise_pred = self.call_unet(
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
ref_loss = train_util.conditional_loss(
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
ref_loss = apply_masked_loss(ref_loss, batch)
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean()
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(multipliers)
scale_term = -0.5 * args.beta_dpo
inside_term = scale_term * (model_diff - ref_diff)
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
accelerator.log({
"total_loss": model_loss.detach().mean().item(),
"raw_model_loss": raw_model_loss.detach().mean().item(),
"ref_loss": raw_ref_loss.detach().item(),
"implicit_acc": implicit_acc.detach().item(),
}, step=global_step)
elif args.mapo_weight is not None:
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
snr = 0.5
model_losses_w, model_losses_l = model_loss.chunk(2)
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
snr * model_losses_l
) / (torch.exp(snr * model_losses_l) - 1)
# Ratio loss.
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
ratio = torch.nn.functional.logsigmoid(log_odds * noise_scheduler.config.num_train_timesteps)
ratio_losses = args.mapo_weight * ratio
# Full MaPO loss
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
accelerator.log({
"total_loss": loss.detach().mean().item(),
"ratio_loss": -ratio_losses.mean().detach().item(),
"model_losses_w": model_losses_w.mean().detach().item(),
"model_losses_l": model_losses_l.mean().detach().item(),
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1))
.mean()
.detach()
.item(),
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1))
.mean()
.detach()
.item()
}, step=global_step)
else:
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
timesteps = [timesteps[0]] * loss.shape[0]
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
@@ -983,7 +1066,8 @@ class NetworkTrainer:
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
# accelerator.log(logs, step=epoch + 1)
accelerator.log(logs, step=global_step)
accelerator.wait_for_everyone()