diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 2dc86bf4..888e1850 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -150,6 +150,28 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", ) + parser.add_argument( + "--beta_dpo", + type=int, + help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000", + ) + parser.add_argument( + "--mapo_weight", + type=float, + help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 ~ 0.25 です", + ) + parser.add_argument( + "--ddo_alpha", + type=float, + help="Controls weight of the fake samples loss term (range: 0.5-50). Higher values increase penalty on reference model samples. Start with 4.0.", + ) + parser.add_argument( + "--ddo_beta", + type=float, + help="Scaling factor for likelihood ratio (range: 0.01-0.1). Higher values create stronger separation between target and reference distributions. Start with 0.05.", + ) + + re_attention = re.compile( r""" @@ -579,6 +601,28 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05): + """ + Implements Direct Discriminative Optimization (DDO) loss. + + DDO bridges likelihood-based generative training with GAN objectives + by parameterizing a discriminator using the likelihood ratio between + a learnable target model and a fixed reference model. + + Args: + loss: Loss value from the target model being optimized + ref_loss: Loss value from the reference model (should be detached) + ddo_alpha: Weight coefficient for the fake samples loss term. + Controls the balance between real/fake samples in training. + Higher values increase penalty on reference model samples. + ddo_beta: Scaling factor for the likelihood ratio to control gradient magnitude. + Smaller values produce a smoother optimization landscape. + Too large values can lead to numerical instability. + + Returns: + tuple: (total_loss, metrics_dict) + - total_loss: Combined DDO loss for optimization + - metrics_dict: Dictionary containing component losses for monitoring + """ ref_loss = ref_loss.detach() # Ensure no gradients to reference log_ratio = ddo_beta * (ref_loss - loss) real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean() @@ -592,10 +636,6 @@ def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05): "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(), } - # logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}") - # logger.debug(f"difference: {(ref_loss - loss).mean().item()}") - # logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}") - # logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}") return total_loss, metrics diff --git a/library/train_util.py b/library/train_util.py index d58361f7..d79f34a7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4268,17 +4268,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", ) - parser.add_argument( - "--beta_dpo", - type=int, - help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000", - ) - parser.add_argument( - "--mapo_weight", - type=float, - help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 ~ 0.25 です", - ) - if support_dreambooth: # DreamBooth training parser.add_argument(