Fix DDO arguments

rockerBOO
2025-05-04 22:19:39 -04:00
parent fe497291b5
commit 971387ea8c
2 changed files with 44 additions and 15 deletions

@@ -150,6 +150,28 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
)
    parser.add_argument(
        "--beta_dpo",
        type=int,
        help="DPO KL divergence penalty β. Recommended values: β=2000 for SD1.5, β=5000 for SDXL / DPO の KL 発散ペナルティ β。推奨値は SD1.5 で β=2000、SDXL で β=5000",
    )
    parser.add_argument(
        "--mapo_weight",
        type=float,
        help="MaPO weight for the relative ratio loss. Recommended values: 0.1 to 0.25 / 相対比損失の MaPO 重み。推奨値は 0.1 ~ 0.25 です",
    )
    parser.add_argument(
        "--ddo_alpha",
        type=float,
        help="Controls the weight of the fake-samples loss term (range: 0.5-50). Higher values increase the penalty on reference-model samples. Start with 4.0.",
    )
    parser.add_argument(
        "--ddo_beta",
        type=float,
        help="Scaling factor for the likelihood ratio (range: 0.01-0.1). Higher values create stronger separation between the target and reference distributions. Start with 0.05.",
    )
re_attention = re.compile(
r"""
@@ -579,6 +601,28 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
def ddo_loss(loss: torch.Tensor, ref_loss: torch.Tensor, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
    """
    Implements the Direct Discriminative Optimization (DDO) loss.

    DDO bridges likelihood-based generative training with GAN objectives
    by parameterizing a discriminator using the likelihood ratio between
    a learnable target model and a fixed reference model.

    Args:
        loss: Loss value from the target model being optimized.
        ref_loss: Loss value from the reference model (should be detached).
        ddo_alpha: Weight coefficient for the fake-samples loss term.
            Controls the balance between real and fake samples in training.
            Higher values increase the penalty on reference-model samples.
        ddo_beta: Scaling factor for the likelihood ratio, controlling gradient magnitude.
            Smaller values produce a smoother optimization landscape;
            too-large values can lead to numerical instability.

    Returns:
        tuple: (total_loss, metrics_dict)
            - total_loss: Combined DDO loss for optimization
            - metrics_dict: Dictionary of component losses for monitoring
    """
    ref_loss = ref_loss.detach()  # ensure no gradients flow to the reference model
    log_ratio = ddo_beta * (ref_loss - loss)
    real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean()
@@ -592,10 +636,6 @@ def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
}
# logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}")
# logger.debug(f"difference: {(ref_loss - loss).mean().item()}")
# logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}")
# logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}")
return total_loss, metrics
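
To make the ddo_beta guidance concrete, a small standalone illustration (again, not part of the commit) of the real-sample term visible above: ddo_beta scales the per-sample loss gap (ref_loss - loss) before the sigmoid, so larger values separate the target and reference distributions more sharply while risking saturation. The gap values below are made up:

import torch

# Illustrative per-sample loss gaps (ref_loss - loss); values are made up.
gap = torch.tensor([-0.5, 0.0, 0.5])

for ddo_beta in (0.01, 0.05, 0.1):
    log_ratio = ddo_beta * gap
    real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6)
    print(ddo_beta, [round(v, 4) for v in real_loss.tolist()])
# As ddo_beta grows, negative gaps (target worse than reference) are
# penalized more steeply, and positive gaps drive the term toward zero faster.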

@@ -4268,17 +4268,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
        default=None,
        help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
    )
    parser.add_argument(
        "--beta_dpo",
        type=int,
        help="DPO KL divergence penalty β. Recommended values: β=2000 for SD1.5, β=5000 for SDXL / DPO の KL 発散ペナルティ β。推奨値は SD1.5 で β=2000、SDXL で β=5000",
    )
    parser.add_argument(
        "--mapo_weight",
        type=float,
        help="MaPO weight for the relative ratio loss. Recommended values: 0.1 to 0.25 / 相対比損失の MaPO 重み。推奨値は 0.1 ~ 0.25 です",
    )
    if support_dreambooth:
        # DreamBooth training
        parser.add_argument(