mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
Preference optimization with MaPO and Diffusion-DPO
This commit is contained in:
@@ -78,6 +78,11 @@ class BaseSubsetParams:
|
||||
caption_tag_dropout_rate: float = 0.0
|
||||
token_warmup_min: int = 1
|
||||
token_warmup_step: float = 0
|
||||
preference: bool = False
|
||||
preference_caption_prefix: Optional[str] = None
|
||||
preference_caption_suffix: Optional[str] = None
|
||||
non_preference_caption_prefix: Optional[str] = None
|
||||
non_preference_caption_suffix: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -199,6 +204,11 @@ class ConfigSanitizer:
|
||||
"token_warmup_step": Any(float, int),
|
||||
"caption_prefix": str,
|
||||
"caption_suffix": str,
|
||||
"preference": bool,
|
||||
"preference_caption_prefix": str,
|
||||
"preference_caption_suffix": str,
|
||||
"non_preference_caption_prefix": str,
|
||||
"non_preference_caption_suffix": str
|
||||
}
|
||||
# DO means DropOut
|
||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
@@ -540,14 +550,29 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
flip_aug: {subset.flip_aug}
|
||||
face_crop_aug_range: {subset.face_crop_aug_range}
|
||||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask},
|
||||
token_warmup_min: {subset.token_warmup_min}
|
||||
token_warmup_step: {subset.token_warmup_step}
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
preference: {subset.preference}
|
||||
"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
if subset.preference:
|
||||
info += indent(
|
||||
dedent(
|
||||
|
||||
f"""\
|
||||
preference_caption_prefix: {subset.preference_caption_prefix}
|
||||
preference_caption_suffix: {subset.preference_caption_suffix}
|
||||
non_preference_caption_prefix: {subset.non_preference_caption_prefix}
|
||||
non_preference_caption_suffix: {subset.non_preference_caption_suffix}
|
||||
\n"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
if is_dreambooth:
|
||||
info += indent(
|
||||
dedent(
|
||||
|
||||
@@ -162,6 +162,20 @@ class ImageInfo:
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
|
||||
class ImageSetInfo(ImageInfo):
    """ImageInfo that groups several images under a single image key.

    Three parallel lists are kept: ``absolute_paths``, ``captions`` and
    ``image_sizes``. Index 0 of each list describes the image handed to the
    constructor; every call to :meth:`add` appends one more entry to all three.
    """

    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
        """Create the set with its first image.

        Args:
            image_key: key of the image set, forwarded to ``ImageInfo``.
            num_repeats: repeat count, forwarded to ``ImageInfo``.
            caption: caption of the first image.
            is_reg: regularization flag, forwarded to ``ImageInfo``.
            absolute_path: path of the first image.
        """
        super().__init__(image_key, num_repeats, caption, is_reg, absolute_path)

        self.absolute_paths = [absolute_path]
        self.captions = [caption]
        # The size of the first image is not known here. Seed the list with a
        # placeholder so that index i of all three lists always refers to the
        # same image, even when the caller never assigns image_sizes itself.
        self.image_sizes = [None]

    def add(self, absolute_path, caption, size):
        """Append one more image to the set.

        Args:
            absolute_path: path of the additional image.
            caption: caption of the additional image.
            size: size of the additional image, or None when it is unknown.
        """
        self.absolute_paths.append(absolute_path)
        self.captions.append(caption)
        self.image_sizes.append(size)
|
||||
|
||||
|
||||
class BucketManager:
|
||||
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
||||
if max_size is not None:
|
||||
@@ -381,6 +395,11 @@ class BaseSubset:
|
||||
caption_suffix: Optional[str],
|
||||
token_warmup_min: int,
|
||||
token_warmup_step: Union[float, int],
|
||||
preference: bool,
|
||||
preference_caption_prefix: Optional[str],
|
||||
preference_caption_suffix: Optional[str],
|
||||
non_preference_caption_prefix: Optional[str],
|
||||
non_preference_caption_suffix: Optional[str],
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||
@@ -404,6 +423,12 @@ class BaseSubset:
|
||||
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
|
||||
self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる
|
||||
|
||||
self.preference = preference

# Each attribute stores the constructor argument of the same name. The
# preferred prefix/suffix must not be taken from the non_preference_*
# arguments, otherwise the preferred-caption settings are silently lost.
self.preference_caption_prefix = preference_caption_prefix
self.preference_caption_suffix = preference_caption_suffix
self.non_preference_caption_prefix = non_preference_caption_prefix
self.non_preference_caption_suffix = non_preference_caption_suffix
|
||||
|
||||
self.img_count = 0
|
||||
|
||||
|
||||
@@ -434,6 +459,11 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
preference: bool,
|
||||
preference_caption_prefix,
|
||||
preference_caption_suffix,
|
||||
non_preference_caption_prefix,
|
||||
non_preference_caption_suffix,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -458,6 +488,11 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
preference,
|
||||
preference_caption_prefix,
|
||||
preference_caption_suffix,
|
||||
non_preference_caption_prefix,
|
||||
non_preference_caption_suffix,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
@@ -1098,6 +1133,75 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def get_image_size(self, image_path):
|
||||
return imagesize.get(image_path)
|
||||
|
||||
def load_and_transform_image(self, subset, image_info, absolute_path, flipped):
    """Load one image file and turn it into a training tensor.

    Args:
        subset: subset settings; ``alpha_mask``, ``random_crop`` and
            ``color_aug`` are read here.
        image_info: image record; ``bucket_reso`` and ``resized_size`` are
            read when bucketing is enabled.
        absolute_path: path of the file to load. This may differ from
            ``image_info.absolute_path`` when the record holds several images.
        flipped: when true the image is mirrored horizontally.

    Returns:
        A tuple ``(image, original_size, crop_left_top, alpha_mask)``.
        ``image`` is the output of ``self.image_transforms``; ``alpha_mask``
        is a float tensor of shape [H, W] or None when ``subset.alpha_mask``
        is false.

    Raises:
        AssertionError: when bucketing is disabled and the image is larger
            than the target size without ``random_crop``, or when the image
            does not end up exactly ``self.height`` x ``self.width``.

    NOTE(review): the third return value is a 2-tuple (left, top). The call
    sites visible in this change bind it to a variable named ``crop_ltrb``;
    confirm that they do not index it as a 4-tuple afterwards.
    """
    # load the image and crop it if required
    img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
        subset, absolute_path, subset.alpha_mask
    )
    im_h, im_w = img.shape[0:2]

    if self.enable_bucket:
        img, original_size, crop_ltrb = trim_and_resize_if_required(
            subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
        )
    else:
        if face_cx > 0:  # face position information is available
            img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
        elif im_h > self.height or im_w > self.width:
            # report the file that was actually loaded, not the first file of the record
            assert (
                subset.random_crop
            ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {absolute_path}"
            if im_h > self.height:
                p = random.randint(0, im_h - self.height)
                img = img[p : p + self.height]
            if im_w > self.width:
                p = random.randint(0, im_w - self.width)
                img = img[:, p : p + self.width]

        im_h, im_w = img.shape[0:2]
        assert (
            im_h == self.height and im_w == self.width
        ), f"image size is small / 画像サイズが小さいようです: {absolute_path}"

        original_size = [im_w, im_h]
        crop_ltrb = (0, 0, 0, 0)

    # augmentation
    aug = self.aug_helper.get_augmentor(subset.color_aug)
    if aug is not None:
        # augment RGB channels only
        img_rgb = img[:, :, :3]
        img_rgb = aug(image=img_rgb)["image"]
        img[:, :, :3] = img_rgb

    if flipped:
        img = img[:, ::-1, :].copy()  # copy to avoid negative stride problem

    if subset.alpha_mask:
        if img.shape[2] == 4:
            alpha_mask = img[:, :, 3]  # [H,W]
            alpha_mask = alpha_mask.astype(np.float32) / 255.0  # 0.0~1.0
            alpha_mask = torch.FloatTensor(alpha_mask)
        else:
            # no alpha channel in the file: treat every pixel as fully opaque
            alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
    else:
        alpha_mask = None

    img = img[:, :, :3]  # remove alpha channel

    image = self.image_transforms(img)  # the transform output is documented upstream as a -1.0~1.0 torch.Tensor

    # An image is always produced on this path, so the size is taken from it
    # directly; there are no latents in this function to fall back to.
    target_size = (image.shape[2], image.shape[1])

    if not flipped:
        crop_left_top = (crop_ltrb[0], crop_ltrb[1])
    else:
        # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
        crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])

    return image, original_size, crop_left_top, alpha_mask
|
||||
|
||||
def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False):
|
||||
img = load_image(image_path, alpha_mask)
|
||||
|
||||
@@ -1171,20 +1275,26 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if self.caching_mode is not None: # return batch for latents/text encoder outputs caching
|
||||
return self.get_item_for_caching(bucket, bucket_batch_size, image_index)
|
||||
|
||||
loss_weights = []
|
||||
captions = []
|
||||
input_ids_list = []
|
||||
input_ids2_list = []
|
||||
latents_list = []
|
||||
alpha_mask_list = []
|
||||
images = []
|
||||
original_sizes_hw = []
|
||||
crop_top_lefts = []
|
||||
target_sizes_hw = []
|
||||
flippeds = [] # 変数名が微妙
|
||||
text_encoder_outputs1_list = []
|
||||
text_encoder_outputs2_list = []
|
||||
text_encoder_pool2_list = []
|
||||
loss_weights: List[float] = []
|
||||
|
||||
images: List[Image] = []
|
||||
latents_list: List[torch.Tensor] = []
|
||||
alpha_mask_list: List[torch.Tensor] = []
|
||||
|
||||
original_sizes_hw: List[Tuple[int, int]] = []
|
||||
crop_top_lefts: List[Tuple[int, int]] = []
|
||||
target_sizes_hw: List[Tuple[int, int]] = []
|
||||
|
||||
flippeds: List[bool] = [] # 変数名が微妙
|
||||
|
||||
captions: List[str] = []
|
||||
|
||||
input_ids_list: List[List[int]] = []
|
||||
input_ids2_list: List[List[int]] = []
|
||||
|
||||
text_encoder_outputs1_list: List[Optional[torch.Tensor]] = []
|
||||
text_encoder_outputs2_list: List[Optional[torch.Tensor]] = []
|
||||
text_encoder_pool2_list: List[Optional[torch.Tensor]] = []
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
@@ -1207,6 +1317,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
|
||||
|
||||
image = None
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz)
|
||||
if flipped:
|
||||
@@ -1218,69 +1331,22 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
|
||||
image = None
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
|
||||
subset, image_info.absolute_path, subset.alpha_mask
|
||||
)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
|
||||
)
|
||||
if isinstance(image_info, ImageSetInfo):
|
||||
for absolute_path in image_info.absolute_paths:
|
||||
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, absolute_path, flipped)
|
||||
images.append(image)
|
||||
latents_list.append(None)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
else:
|
||||
if face_cx > 0: # 顔位置情報あり
|
||||
img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h)
|
||||
elif im_h > self.height or im_w > self.width:
|
||||
assert (
|
||||
subset.random_crop
|
||||
), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
||||
if im_h > self.height:
|
||||
p = random.randint(0, im_h - self.height)
|
||||
img = img[p : p + self.height]
|
||||
if im_w > self.width:
|
||||
p = random.randint(0, im_w - self.width)
|
||||
img = img[:, p : p + self.width]
|
||||
|
||||
im_h, im_w = img.shape[0:2]
|
||||
assert (
|
||||
im_h == self.height and im_w == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
original_size = [im_w, im_h]
|
||||
crop_ltrb = (0, 0, 0, 0)
|
||||
|
||||
# augmentation
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug)
|
||||
if aug is not None:
|
||||
# augment RGB channels only
|
||||
img_rgb = img[:, :, :3]
|
||||
img_rgb = aug(image=img_rgb)["image"]
|
||||
img[:, :, :3] = img_rgb
|
||||
|
||||
if flipped:
|
||||
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem
|
||||
|
||||
if subset.alpha_mask:
|
||||
if img.shape[2] == 4:
|
||||
alpha_mask = img[:, :, 3] # [H,W]
|
||||
alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
else:
|
||||
alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
|
||||
else:
|
||||
alpha_mask = None
|
||||
|
||||
img = img[:, :, :3] # remove alpha channel
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
del img
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped)
|
||||
images.append(image)
|
||||
latents_list.append(None)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
|
||||
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
|
||||
@@ -1311,31 +1377,41 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
text_encoder_pool2_list.append(text_encoder_pool2)
|
||||
captions.append(caption)
|
||||
else:
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
if self.XTI_layers:
|
||||
caption_layer = []
|
||||
for layer in self.XTI_layers:
|
||||
token_strings_from = " ".join(self.token_strings)
|
||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
caption_layer.append(caption_)
|
||||
captions.append(caption_layer)
|
||||
captions_list = []
|
||||
|
||||
if isinstance(image_info, ImageSetInfo):
|
||||
for image_info_caption in image_info.captions:
|
||||
caption = self.process_caption(subset, image_info_caption)
|
||||
captions_list.append(caption)
|
||||
else:
|
||||
captions.append(caption)
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
captions_list.append(caption)
|
||||
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
for caption in captions_list:
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
caption_layer = []
|
||||
for layer in self.XTI_layers:
|
||||
token_strings_from = " ".join(self.token_strings)
|
||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
caption_layer.append(caption_)
|
||||
captions.append(caption_layer)
|
||||
else:
|
||||
token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
input_ids_list.append(token_caption)
|
||||
captions.append(caption)
|
||||
|
||||
if len(self.tokenizers) > 1:
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
if self.XTI_layers:
|
||||
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
else:
|
||||
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
input_ids2_list.append(token_caption2)
|
||||
token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
input_ids_list.append(token_caption)
|
||||
|
||||
if len(self.tokenizers) > 1:
|
||||
if self.XTI_layers:
|
||||
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
else:
|
||||
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
input_ids2_list.append(token_caption2)
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
@@ -1568,6 +1644,11 @@ class DreamBoothDataset(BaseDataset):
|
||||
img_paths = list(metas.keys())
|
||||
sizes = [meta["resolution"] for meta in metas.values()]
|
||||
|
||||
elif subset.preference:
|
||||
# We assume a image_dir path pattern for winner/loser
|
||||
winner_path = str(pathlib.Path(subset.image_dir) / "w")
|
||||
img_paths = glob_images(winner_path, "*")
|
||||
sizes = [None] * len(img_paths)
|
||||
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
|
||||
else:
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
@@ -1656,9 +1737,41 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.preference:
|
||||
def get_non_preferred_pair_info(img_path, subset):
    """Build the entry for the non-preferred image paired with ``img_path``.

    The non-preferred file is looked up under a sibling directory whose name
    is the preferred directory name with every 'w' replaced by 'l', using the
    same file name.

    Args:
        img_path: path of the preferred image.
        subset: subset settings; the caption extension, wildcard flag and the
            non-preference caption prefix/suffix are read here.

    Returns:
        A tuple ``(loser_img_path, caption, image_size)``. ``image_size`` is
        None when the enclosing loop has no size for the preferred image.
    """
    winner_dir, file_name = os.path.split(img_path)
    parent_dir, winner_dir_name = os.path.split(winner_dir)
    loser_dir_name = winner_dir_name.replace('w', 'l')
    loser_img_path = os.path.join(parent_dir, loser_dir_name, file_name)

    # The caption is read via the preferred image's path and then decorated
    # with the non-preference prefix/suffix.
    caption = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)

    if subset.non_preference_caption_prefix:
        caption = subset.non_preference_caption_prefix + " " + caption
    if subset.non_preference_caption_suffix:
        caption = caption + " " + subset.non_preference_caption_suffix

    # Measure the non-preferred file itself; measuring img_path would only
    # repeat the size of the preferred image. `size` comes from the enclosing
    # loop and only decides whether sizes are collected at all.
    image_size = self.get_image_size(loser_img_path) if size is not None else None

    return loser_img_path, caption, image_size
|
||||
|
||||
if subset.preference_caption_prefix:
|
||||
caption = subset.preference_caption_prefix + " " + caption
|
||||
if subset.preference_caption_suffix:
|
||||
caption = caption + " " + subset.preference_caption_suffix
|
||||
|
||||
info = ImageSetInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
info.image_sizes = [size]
|
||||
else:
|
||||
info.image_sizes = [None]
|
||||
info.add(*get_non_preferred_pair_info(img_path, subset))
|
||||
else:
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
|
||||
if subset.is_reg:
|
||||
reg_infos.append((info, subset))
|
||||
else:
|
||||
@@ -1968,6 +2081,7 @@ class ControlNetDataset(BaseDataset):
|
||||
subset.caption_suffix,
|
||||
subset.token_warmup_min,
|
||||
subset.token_warmup_step,
|
||||
subset.preference
|
||||
)
|
||||
db_subsets.append(db_subset)
|
||||
|
||||
@@ -3471,6 +3585,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beta_dpo",
|
||||
type=int,
|
||||
help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--orpo_weight",
|
||||
type=float,
|
||||
help="ORPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 ~ 0.25 です",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mapo_weight",
|
||||
type=float,
|
||||
help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 ~ 0.25 です",
|
||||
)
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
@@ -3714,6 +3843,30 @@ def add_dataset_arguments(
|
||||
default=None,
|
||||
help="suffix for caption text / captionのテキストの末尾に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preference_caption_prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="prefix for preference caption text / captionのテキストの先頭に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preference_caption_suffix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="suffix for preference caption text / captionのテキストの末尾に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non_preference_caption_prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="prefix for non-preference caption text / captionのテキストの先頭に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non_preference_caption_suffix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="suffix for non-preference caption text / captionのテキストの末尾に付ける文字列",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする"
|
||||
)
|
||||
@@ -5075,6 +5228,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||
|
||||
# Sample a random timestep for each image
|
||||
b_size = latents.shape[0]
|
||||
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
|
||||
@@ -5488,7 +5642,7 @@ def sample_image_inference(
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)}, step=steps)
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
|
||||
|
||||
@@ -906,10 +906,93 @@ class NetworkTrainer:
|
||||
)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.beta_dpo is not None:
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
model_loss_w, model_loss_l = model_loss.chunk(2)
|
||||
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
|
||||
model_diff = model_loss_w - model_loss_l
|
||||
|
||||
# ref loss
|
||||
with torch.no_grad():
|
||||
# disable network for reference
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
|
||||
with accelerator.autocast():
|
||||
ref_noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
)
|
||||
ref_loss = train_util.conditional_loss(
|
||||
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
ref_loss = apply_masked_loss(ref_loss, batch)
|
||||
|
||||
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
raw_ref_loss = ref_loss.mean()
|
||||
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||
|
||||
scale_term = -0.5 * args.beta_dpo
|
||||
inside_term = scale_term * (model_diff - ref_diff)
|
||||
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
|
||||
|
||||
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
|
||||
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
|
||||
|
||||
accelerator.log({
|
||||
"total_loss": model_loss.detach().mean().item(),
|
||||
"raw_model_loss": raw_model_loss.detach().mean().item(),
|
||||
"ref_loss": raw_ref_loss.detach().item(),
|
||||
"implicit_acc": implicit_acc.detach().item(),
|
||||
}, step=global_step)
|
||||
elif args.mapo_weight is not None:
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
|
||||
snr = 0.5
|
||||
model_losses_w, model_losses_l = model_loss.chunk(2)
|
||||
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
|
||||
snr * model_losses_l
|
||||
) / (torch.exp(snr * model_losses_l) - 1)
|
||||
|
||||
# Ratio loss.
|
||||
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
|
||||
ratio = torch.nn.functional.logsigmoid(log_odds * noise_scheduler.config.num_train_timesteps)
|
||||
ratio_losses = args.mapo_weight * ratio
|
||||
|
||||
# Full MaPO loss
|
||||
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
|
||||
|
||||
accelerator.log({
|
||||
"total_loss": loss.detach().mean().item(),
|
||||
"ratio_loss": -ratio_losses.mean().detach().item(),
|
||||
"model_losses_w": model_losses_w.mean().detach().item(),
|
||||
"model_losses_l": model_losses_l.mean().detach().item(),
|
||||
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1))
|
||||
.mean()
|
||||
.detach()
|
||||
.item(),
|
||||
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1))
|
||||
.mean()
|
||||
.detach()
|
||||
.item()
|
||||
}, step=global_step)
|
||||
else:
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
timesteps = [timesteps[0]] * loss.shape[0]
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
@@ -983,7 +1066,8 @@ class NetworkTrainer:
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user