diff --git a/library/config_util.py b/library/config_util.py index 10b2457f..3ab0de11 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -78,6 +78,11 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 + preference: bool = False + preference_caption_prefix: Optional[str] = None + preference_caption_suffix: Optional[str] = None + non_preference_caption_prefix: Optional[str] = None + non_preference_caption_suffix: Optional[str] = None @dataclass @@ -199,6 +204,11 @@ class ConfigSanitizer: "token_warmup_step": Any(float, int), "caption_prefix": str, "caption_suffix": str, + "preference": bool, + "preference_caption_prefix": str, + "preference_caption_suffix": str, + "non_preference_caption_prefix": str, + "non_preference_caption_suffix": str } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -540,14 +550,29 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - alpha_mask: {subset.alpha_mask}, + token_warmup_min: {subset.token_warmup_min} + token_warmup_step: {subset.token_warmup_step} + alpha_mask: {subset.alpha_mask} + preference: {subset.preference} """ ), " ", ) + if subset.preference: + info += indent( + dedent( + + f"""\ + preference_caption_prefix: {subset.preference_caption_prefix} + preference_caption_suffix: {subset.preference_caption_suffix} + non_preference_caption_prefix: {subset.non_preference_caption_prefix} + non_preference_caption_suffix: {subset.non_preference_caption_suffix} + \n""" + ), + " ", + ) + if is_dreambooth: info += indent( dedent( diff --git a/library/train_util.py b/library/train_util.py index 566f5927..a5092685 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -162,6 +162,20 @@ class ImageInfo: 
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime +class ImageSetInfo(ImageInfo): + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + super().__init__(image_key, num_repeats, caption, is_reg, absolute_path) + + self.absolute_paths = [absolute_path] + self.captions = [caption] + self.image_sizes = [] + + def add(self, absolute_path, caption, size): + self.absolute_paths.append(absolute_path) + self.captions.append(caption) + self.image_sizes.append(size) + + class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: if max_size is not None: @@ -381,6 +395,11 @@ class BaseSubset: caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], + preference: bool, + preference_caption_prefix: Optional[str], + preference_caption_suffix: Optional[str], + non_preference_caption_prefix: Optional[str], + non_preference_caption_suffix: Optional[str], ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -404,6 +423,12 @@ class BaseSubset: self.token_warmup_min = token_warmup_min # step=0におけるタグの数 self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.preference = preference + self.preference_caption_prefix = preference_caption_prefix + self.preference_caption_suffix = preference_caption_suffix + self.non_preference_caption_prefix = non_preference_caption_prefix + self.non_preference_caption_suffix = non_preference_caption_suffix + self.img_count = 0 @@ -434,6 +459,11 @@ class DreamBoothSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + preference: bool, + preference_caption_prefix, + preference_caption_suffix, + non_preference_caption_prefix, + non_preference_caption_suffix, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -458,6 +488,11 @@ 
class DreamBoothSubset(BaseSubset): caption_suffix, token_warmup_min, token_warmup_step, + preference, + preference_caption_prefix, + preference_caption_suffix, + non_preference_caption_prefix, + non_preference_caption_suffix, ) self.is_reg = is_reg @@ -1098,6 +1133,75 @@ class BaseDataset(torch.utils.data.Dataset): def get_image_size(self, image_path): return imagesize.get(image_path) + def load_and_transform_image(self, subset, image_info, absolute_path, flipped): + # 画像を読み込み、必要ならcropする + + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( + subset, absolute_path, subset.alpha_mask + ) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img, original_size, crop_ltrb = trim_and_resize_if_required( + subset.random_crop, img, image_info.bucket_reso, image_info.resized_size + ) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert ( + subset.random_crop + ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p : p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p : p + self.width] + + im_h, im_w = img.shape[0:2] + assert ( + im_h == self.height and im_w == self.width + ), f"image size is small / 画像サイズが小さいようです: {absolute_path}" + + original_size = [im_w, im_h] + crop_ltrb = (0, 0, 0, 0) + + # augmentation + aug = self.aug_helper.get_augmentor(subset.color_aug) + if aug is not None: + # augment RGB channels only + img_rgb = img[:, :, :3] + img_rgb = aug(image=img_rgb)["image"] + img[:, :, :3] = img_rgb + + if flipped: + img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + + if subset.alpha_mask: + if img.shape[2] == 4: + alpha_mask = img[:, :, 3] # [H,W] + alpha_mask = 
alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + alpha_mask = torch.FloatTensor(alpha_mask) + else: + alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) + else: + alpha_mask = None + + img = img[:, :, :3] # remove alpha channel + + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) + + if not flipped: + crop_left_top = (crop_ltrb[0], crop_ltrb[1]) + else: + # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image + crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1]) + + return image, original_size, crop_left_top, alpha_mask + def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): img = load_image(image_path, alpha_mask) @@ -1171,20 +1275,26 @@ class BaseDataset(torch.utils.data.Dataset): if self.caching_mode is not None: # return batch for latents/text encoder outputs caching return self.get_item_for_caching(bucket, bucket_batch_size, image_index) - loss_weights = [] - captions = [] - input_ids_list = [] - input_ids2_list = [] - latents_list = [] - alpha_mask_list = [] - images = [] - original_sizes_hw = [] - crop_top_lefts = [] - target_sizes_hw = [] - flippeds = [] # 変数名が微妙 - text_encoder_outputs1_list = [] - text_encoder_outputs2_list = [] - text_encoder_pool2_list = [] + loss_weights: List[float] = [] + + images: List[Image] = [] + latents_list: List[torch.Tensor] = [] + alpha_mask_list: List[torch.Tensor] = [] + + original_sizes_hw: List[Tuple[int, int]] = [] + crop_top_lefts: List[Tuple[int, int]] = [] + target_sizes_hw: List[Tuple[int, int]] = [] + + flippeds: List[bool] = [] # 変数名が微妙 + + captions: List[str] = [] + + input_ids_list: List[List[int]] = [] + input_ids2_list: List[List[int]] = [] + + text_encoder_outputs1_list: List[Optional[torch.Tensor]] = [] + text_encoder_outputs2_list: List[Optional[torch.Tensor]] = [] + 
text_encoder_pool2_list: List[Optional[torch.Tensor]] = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] @@ -1207,6 +1317,9 @@ class BaseDataset(torch.utils.data.Dataset): alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) image = None + + images.append(image) + latents_list.append(latents) elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) if flipped: @@ -1218,69 +1331,22 @@ class BaseDataset(torch.utils.data.Dataset): alpha_mask = torch.FloatTensor(alpha_mask) image = None + + images.append(image) + latents_list.append(latents) + alpha_mask_list.append(alpha_mask) else: - # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( - subset, image_info.absolute_path, subset.alpha_mask - ) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size - ) + if isinstance(image_info, ImageSetInfo): + for absolute_path in image_info.absolute_paths: + image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, absolute_path, flipped) + images.append(image) + latents_list.append(None) + alpha_mask_list.append(alpha_mask) else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert ( - subset.random_crop - ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p : p + self.height] - if im_w > self.width: - p = random.randint(0, im_w 
- self.width) - img = img[:, p : p + self.width] - - im_h, im_w = img.shape[0:2] - assert ( - im_h == self.height and im_w == self.width - ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" - - original_size = [im_w, im_h] - crop_ltrb = (0, 0, 0, 0) - - # augmentation - aug = self.aug_helper.get_augmentor(subset.color_aug) - if aug is not None: - # augment RGB channels only - img_rgb = img[:, :, :3] - img_rgb = aug(image=img_rgb)["image"] - img[:, :, :3] = img_rgb - - if flipped: - img = img[:, ::-1, :].copy() # copy to avoid negative stride problem - - if subset.alpha_mask: - if img.shape[2] == 4: - alpha_mask = img[:, :, 3] # [H,W] - alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 - alpha_mask = torch.FloatTensor(alpha_mask) - else: - alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) - else: - alpha_mask = None - - img = img[:, :, :3] # remove alpha channel - - latents = None - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - del img - - images.append(image) - latents_list.append(latents) - alpha_mask_list.append(alpha_mask) + image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped) + images.append(image) + latents_list.append(None) + alpha_mask_list.append(alpha_mask) target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) @@ -1311,31 +1377,41 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_pool2_list.append(text_encoder_pool2) captions.append(caption) else: - caption = self.process_caption(subset, image_info.caption) - if self.XTI_layers: - caption_layer = [] - for layer in self.XTI_layers: - token_strings_from = " ".join(self.token_strings) - token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - caption_ = caption.replace(token_strings_from, token_strings_to) - caption_layer.append(caption_) - 
captions.append(caption_layer) + captions_list = [] + + if isinstance(image_info, ImageSetInfo): + for image_info_caption in image_info.captions: + caption = self.process_caption(subset, image_info_caption) + captions_list.append(caption) else: - captions.append(caption) + caption = self.process_caption(subset, image_info.caption) + captions_list.append(caption) - if not self.token_padding_disabled: # this option might be omitted in future + for caption in captions_list: if self.XTI_layers: - token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + caption_layer = [] + for layer in self.XTI_layers: + token_strings_from = " ".join(self.token_strings) + token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + caption_ = caption.replace(token_strings_from, token_strings_to) + caption_layer.append(caption_) + captions.append(caption_layer) else: - token_caption = self.get_input_ids(caption, self.tokenizers[0]) - input_ids_list.append(token_caption) + captions.append(caption) - if len(self.tokenizers) > 1: + if not self.token_padding_disabled: # this option might be omitted in future if self.XTI_layers: - token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) else: - token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - input_ids2_list.append(token_caption2) + token_caption = self.get_input_ids(caption, self.tokenizers[0]) + input_ids_list.append(token_caption) + + if len(self.tokenizers) > 1: + if self.XTI_layers: + token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + else: + token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + input_ids2_list.append(token_caption2) example = {} example["loss_weights"] = torch.FloatTensor(loss_weights) @@ -1568,6 +1644,11 @@ class DreamBoothDataset(BaseDataset): img_paths = list(metas.keys()) sizes = [meta["resolution"] for meta in metas.values()] + elif 
subset.preference: + # We assume a image_dir path pattern for winner/loser + winner_path = str(pathlib.Path(subset.image_dir) / "w") + img_paths = glob_images(winner_path, "*") + sizes = [None] * len(img_paths) # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") @@ -1656,9 +1737,41 @@ class DreamBoothDataset(BaseDataset): num_train_images += subset.num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) - if size is not None: - info.image_size = size + if subset.preference: + def get_non_preferred_pair_info(img_path, subset): + head, file = os.path.split(img_path) + head, tail = os.path.split(head) + new_tail = tail.replace('w', 'l') + loser_img_path = os.path.join(head, new_tail, file) + + caption = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) + + if subset.non_preference_caption_prefix: + caption = subset.non_preference_caption_prefix + " " + caption + if subset.non_preference_caption_suffix: + caption = caption + " " + subset.non_preference_caption_suffix + + image_size = self.get_image_size(img_path) if size is not None else None + + return loser_img_path, caption, image_size + + if subset.preference_caption_prefix: + caption = subset.preference_caption_prefix + " " + caption + if subset.preference_caption_suffix: + caption = caption + " " + subset.preference_caption_suffix + + info = ImageSetInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if size is not None: + info.image_size = size + info.image_sizes = [size] + else: + info.image_sizes = [None] + info.add(*get_non_preferred_pair_info(img_path, subset)) + else: + info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if size is not None: + info.image_size = size + if subset.is_reg: 
reg_infos.append((info, subset)) else: @@ -1968,6 +2081,7 @@ class ControlNetDataset(BaseDataset): subset.caption_suffix, subset.token_warmup_min, subset.token_warmup_step, + subset.preference, subset.preference_caption_prefix, subset.preference_caption_suffix, subset.non_preference_caption_prefix, subset.non_preference_caption_suffix ) db_subsets.append(db_subset) @@ -3471,6 +3585,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", ) + parser.add_argument( + "--beta_dpo", + type=int, + help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000", + ) + parser.add_argument( + "--orpo_weight", + type=float, + help="ORPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 ~ 0.25 です", + ) + parser.add_argument( + "--mapo_weight", + type=float, + help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の MaPO 重み。推奨値は 0.1 ~ 0.25 です", + ) if support_dreambooth: # DreamBooth training @@ -3714,6 +3843,30 @@ def add_dataset_arguments( default=None, help="suffix for caption text / captionのテキストの末尾に付ける文字列", ) + parser.add_argument( + "--preference_caption_prefix", + type=str, + default=None, + help="prefix for preference caption text / captionのテキストの先頭に付ける文字列", + ) + parser.add_argument( + "--preference_caption_suffix", + type=str, + default=None, + help="suffix for preference caption text / captionのテキストの末尾に付ける文字列", + ) + parser.add_argument( + "--non_preference_caption_prefix", + type=str, + default=None, + help="prefix for non-preference caption text / captionのテキストの先頭に付ける文字列", + ) + parser.add_argument( + "--non_preference_caption_suffix", + type=str, + default=None, + help="suffix for non-preference caption text / captionのテキストの末尾に付ける文字列", + ) parser.add_argument( "--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする" ) @@ -5075,6 +5228,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, 
latents): # Sample a random timestep for each image b_size = latents.shape[0] + min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep @@ -5488,7 +5642,7 @@ def sample_image_inference( except ImportError: # 事前に一度確認するのでここはエラー出ないはず raise ImportError("No wandb / wandb がインストールされていないようです") - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}, step=steps) except: # wandb 無効時 pass diff --git a/train_network.py b/train_network.py index b272a6e1..d56e0b61 100644 --- a/train_network.py +++ b/train_network.py @@ -906,10 +906,93 @@ class NetworkTrainer: ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) + + if args.beta_dpo is not None: + model_loss = loss.mean(dim=list(range(1, len(loss.shape)))) + model_loss_w, model_loss_l = model_loss.chunk(2) + raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean()) + model_diff = model_loss_w - model_loss_l + + # ref loss + with torch.no_grad(): + # disable network for reference + accelerator.unwrap_model(network).set_multiplier(0.0) + + with accelerator.autocast(): + ref_noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + ref_loss = train_util.conditional_loss( + ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + ref_loss = apply_masked_loss(ref_loss, batch) + + ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape)))) + ref_losses_w, ref_losses_l = ref_loss.chunk(2) + ref_diff = ref_losses_w - ref_losses_l + raw_ref_loss = ref_loss.mean() + + # reset network multipliers + 
accelerator.unwrap_model(network).set_multiplier(multipliers) + + scale_term = -0.5 * args.beta_dpo + inside_term = scale_term * (model_diff - ref_diff) + loss = -1 * torch.nn.functional.logsigmoid(inside_term) + + implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0) + implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) + + accelerator.log({ + "total_loss": model_loss.detach().mean().item(), + "raw_model_loss": raw_model_loss.detach().mean().item(), + "ref_loss": raw_ref_loss.detach().item(), + "implicit_acc": implicit_acc.detach().item(), + }, step=global_step) + elif args.mapo_weight is not None: + model_loss = loss.mean(dim=list(range(1, len(loss.shape)))) + + snr = 0.5 + model_losses_w, model_losses_l = model_loss.chunk(2) + log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - ( + snr * model_losses_l + ) / (torch.exp(snr * model_losses_l) - 1) + + # Ratio loss. + # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process. 
+ ratio = torch.nn.functional.logsigmoid(log_odds * noise_scheduler.config.num_train_timesteps) + ratio_losses = args.mapo_weight * ratio + + # Full MaPO loss + loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape)))) + + accelerator.log({ + "total_loss": loss.detach().mean().item(), + "ratio_loss": -ratio_losses.mean().detach().item(), + "model_losses_w": model_losses_w.mean().detach().item(), + "model_losses_l": model_losses_l.mean().detach().item(), + "win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)) + .mean() + .detach() + .item(), + "lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)) + .mean() + .detach() + .item() + }, step=global_step) + else: + loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights + timesteps = [timesteps[0]] * loss.shape[0] if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) @@ -983,7 +1066,8 @@ class NetworkTrainer: if args.logging_dir is not None: logs = {"loss/epoch": loss_recorder.moving_average} - accelerator.log(logs, step=epoch + 1) + # accelerator.log(logs, step=epoch + 1) + accelerator.log(logs, step=global_step) accelerator.wait_for_everyone()