mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Fix DDO arguments
This commit is contained in:
@@ -150,6 +150,28 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beta_dpo",
|
||||
type=int,
|
||||
help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mapo_weight",
|
||||
type=float,
|
||||
help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 ~ 0.25 です",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddo_alpha",
|
||||
type=float,
|
||||
help="Controls weight of the fake samples loss term (range: 0.5-50). Higher values increase penalty on reference model samples. Start with 4.0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddo_beta",
|
||||
type=float,
|
||||
help="Scaling factor for likelihood ratio (range: 0.01-0.1). Higher values create stronger separation between target and reference distributions. Start with 0.05.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
re_attention = re.compile(
|
||||
r"""
|
||||
@@ -579,6 +601,28 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
|
||||
|
||||
|
||||
def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
|
||||
"""
|
||||
Implements Direct Discriminative Optimization (DDO) loss.
|
||||
|
||||
DDO bridges likelihood-based generative training with GAN objectives
|
||||
by parameterizing a discriminator using the likelihood ratio between
|
||||
a learnable target model and a fixed reference model.
|
||||
|
||||
Args:
|
||||
loss: Loss value from the target model being optimized
|
||||
ref_loss: Loss value from the reference model (should be detached)
|
||||
ddo_alpha: Weight coefficient for the fake samples loss term.
|
||||
Controls the balance between real/fake samples in training.
|
||||
Higher values increase penalty on reference model samples.
|
||||
ddo_beta: Scaling factor for the likelihood ratio to control gradient magnitude.
|
||||
Smaller values produce a smoother optimization landscape.
|
||||
Too large values can lead to numerical instability.
|
||||
|
||||
Returns:
|
||||
tuple: (total_loss, metrics_dict)
|
||||
- total_loss: Combined DDO loss for optimization
|
||||
- metrics_dict: Dictionary containing component losses for monitoring
|
||||
"""
|
||||
ref_loss = ref_loss.detach() # Ensure no gradients to reference
|
||||
log_ratio = ddo_beta * (ref_loss - loss)
|
||||
real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean()
|
||||
@@ -592,10 +636,6 @@ def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
|
||||
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
|
||||
}
|
||||
|
||||
# logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}")
|
||||
# logger.debug(f"difference: {(ref_loss - loss).mean().item()}")
|
||||
# logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}")
|
||||
# logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}")
|
||||
return total_loss, metrics
|
||||
|
||||
|
||||
|
||||
@@ -4268,17 +4268,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beta_dpo",
|
||||
type=int,
|
||||
help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mapo_weight",
|
||||
type=float,
|
||||
help="MaPO weight for relative ratio loss. Recommended values of 0.1 to 0.25 / 相対比損失の ORPO 重み。推奨値は 0.1 ~ 0.25 です",
|
||||
)
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth training
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user