mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
Add BPO, CPO, DDO, SDPO, SimPO
Refactor Preference Optimization. Refactor preference dataset. Add iterator support for ImageInfo and ImageSetInfo: supporting iteration through either ImageInfo or ImageSetInfo cleans up the preference dataset implementation and supports 2 or more images more cleanly without needing to duplicate code. Add tests for all PO functions. Add metrics for process_batch. Add losses for gradient manipulation of loss parts. Add gradient normalization for stabilizing gradients. Args added: mapo_beta = 0.05, cpo_beta = 0.1, bpo_beta = 0.1, bpo_lambda = 0.2, sdpo_beta = 0.02, simpo_gamma_beta_ratio = 0.25, simpo_beta = 2.0, simpo_smoothing = 0.0, simpo_loss_type = "sigmoid", ddo_alpha = 4.0, ddo_beta = 0.05.
This commit is contained in:
@@ -3,6 +3,8 @@ import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Callable, Protocol
|
||||
import math
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
@@ -156,9 +158,57 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000",
|
||||
)
|
||||
# MaPO: the option was renamed from --mapo_weight to --mapo_beta; only the new
# name and its help text are registered.
parser.add_argument(
    "--mapo_beta",
    type=float,
    help="MaPO beta regularization parameter. Recommended values of 0.01 to 0.1 / MaPO のベータ正則化パラメータ。推奨値は 0.01 ~ 0.1 です",
)
|
||||
parser.add_argument(
|
||||
"--cpo_beta",
|
||||
type=float,
|
||||
help="CPO beta regularization parameter. Recommended value of 0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bpo_beta",
|
||||
type=float,
|
||||
help="BPO beta regularization parameter. Recommended value of 0.1",
|
||||
)
|
||||
# BPO lambda is a separate hyperparameter from --bpo_beta (the help text used to
# describe it as the beta regularization parameter).
parser.add_argument(
    "--bpo_lambda",
    type=float,
    help="BPO lambda hyperparameter for SBA. Recommended value of 0.0 to 0.2. -0.5 similar to DPO gradient.",
)
|
||||
parser.add_argument(
|
||||
"--sdpo_beta",
|
||||
type=float,
|
||||
help="SDPO beta regularization parameter. Recommended value of 0.02",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sdpo_epsilon",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="SDPO epsilon for clipping importance weighting. Recommended value of 0.1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--simpo_gamma_beta_ratio",
|
||||
type=float,
|
||||
help="SimPO target reward margin term. Ensure the reward for the chosen exceeds the rejected. Recommended: 0.25-1.75",
|
||||
)
|
||||
# SimPO options. The help strings previously said "SDPO", which is a different
# algorithm configured through --sdpo_beta / --sdpo_epsilon.
parser.add_argument(
    "--simpo_beta",
    type=float,
    help="SimPO beta controls the scaling of the reward difference. Recommended: 2.0-2.5",
)
parser.add_argument(
    "--simpo_smoothing",
    type=float,
    help="SimPO smoothing of chosen/rejected. Recommended: 0.0",
)
parser.add_argument(
    "--simpo_loss_type",
    type=str,
    default="sigmoid",
    choices=["sigmoid", "hinge"],
    help="SimPO loss type. Options: sigmoid, hinge. Default: sigmoid",
)
|
||||
parser.add_argument(
|
||||
"--ddo_alpha",
|
||||
@@ -172,7 +222,6 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
)
|
||||
|
||||
|
||||
|
||||
re_attention = re.compile(
|
||||
r"""
|
||||
\\\(|
|
||||
@@ -532,7 +581,74 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
return loss
|
||||
|
||||
|
||||
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
|
||||
def assert_po_variables(args):
    """Check that paired preference-optimization options are supplied together.

    DDO needs both ``ddo_beta`` and ``ddo_alpha``; BPO needs both ``bpo_beta``
    and ``bpo_lambda``. Each pair is validated independently, so a half-set BPO
    pair is reported even when DDO options are also present.

    Args:
        args: parsed arguments exposing ``ddo_beta``, ``ddo_alpha``,
            ``bpo_beta`` and ``bpo_lambda`` attributes (``None`` when unset).

    Raises:
        AssertionError: if exactly one option of a pair is set.
    """
    paired_options = (("ddo_beta", "ddo_alpha"), ("bpo_beta", "bpo_lambda"))
    for first, second in paired_options:
        is_set = [getattr(args, name) is not None for name in (first, second)]
        if any(is_set) and not all(is_set):
            # Raised explicitly (rather than via ``assert``) so the check still
            # runs under ``python -O`` while keeping the same exception type.
            raise AssertionError(f"Both {first} and {second} must be set together")
|
||||
|
||||
|
||||
class PreferenceOptimization:
    """Pick one preference-optimization loss from parsed arguments and apply it.

    Exactly one algorithm is selected. Reference-based algorithms are checked
    first (DDO, BPO, Diffusion DPO, SDPO, in that order), then reference-free
    ones (MaPO, SimPO, CPO). The keyword arguments stored in ``self.args``
    always belong to the selected loss function.

    Attributes:
        algo: display name of the selected algorithm, or ``None``.
        loss_fn: reference-free loss ``fn(loss, **args)``, or ``None``.
        loss_ref_fn: reference-based loss ``fn(loss, ref_loss, **args)``, or ``None``.
        args: keyword arguments forwarded to the selected loss function.
    """

    def __init__(self, args):
        self.loss_fn = None
        self.loss_ref_fn = None
        # Always defined so callers can inspect them even when nothing is configured.
        self.algo = None
        self.args = {}

        assert_po_variables(args)

        # A single chain: previously a second, independent chain could replace
        # ``self.args`` with reference-free arguments while ``loss_ref_fn`` stayed
        # selected, so the reference loss was called with the wrong parameters.
        if args.ddo_beta is not None or args.ddo_alpha is not None:
            self.algo = "DDO"
            self.loss_ref_fn = ddo_loss
            # Keys match ddo_loss's parameter names (ddo_alpha / ddo_beta).
            self.args = {"ddo_beta": args.ddo_beta, "ddo_alpha": args.ddo_alpha}
        elif args.bpo_beta is not None or args.bpo_lambda is not None:
            self.algo = "BPO"
            self.loss_ref_fn = bpo_loss
            self.args = {"beta": args.bpo_beta, "lambda_": args.bpo_lambda}
        elif args.beta_dpo is not None:
            self.algo = "Diffusion DPO"
            self.loss_ref_fn = diffusion_dpo_loss
            self.args = {"beta": args.beta_dpo}
        elif args.sdpo_beta is not None:
            self.algo = "SDPO"
            self.loss_ref_fn = sdpo_loss
            self.args = {"beta": args.sdpo_beta, "epsilon": args.sdpo_epsilon}
        elif args.mapo_beta is not None:
            self.algo = "MaPO"
            self.loss_fn = mapo_loss
            self.args = {"beta": args.mapo_beta}
        elif args.simpo_beta is not None:
            self.algo = "SimPO"
            self.loss_fn = simpo_loss
            self.args = {
                "beta": args.simpo_beta,
                "gamma_beta_ratio": args.simpo_gamma_beta_ratio,
                "smoothing": args.simpo_smoothing,
                "loss_type": args.simpo_loss_type,
            }
        elif args.cpo_beta is not None:
            self.algo = "CPO"
            self.loss_fn = cpo_loss
            self.args = {"beta": args.cpo_beta}

        # Options left unset on the command line are None; drop them so the loss
        # function's own defaults apply instead of receiving None.
        self.args = {key: value for key, value in self.args.items() if value is not None}

    def is_po(self):
        """Return True when any preference-optimization loss is configured."""
        return self.loss_fn is not None or self.loss_ref_fn is not None

    def is_reference(self):
        """Return True when the configured loss needs reference-model losses."""
        return self.loss_ref_fn is not None

    def __call__(self, loss: torch.Tensor, ref_loss: torch.Tensor | None = None, **kwargs):
        """Apply the configured loss.

        Args:
            loss: training-model loss passed to the selected loss function.
            ref_loss: reference-model loss; required for reference-based algorithms.
            **kwargs: extra per-call keyword arguments for the loss function
                (for example ``w_t`` for DDO); they override stored arguments.

        Returns:
            ``(loss, metrics)`` as returned by the selected loss function.

        Raises:
            AssertionError: if a required reference loss is missing or no loss
                function is configured.
        """
        call_args = {**self.args, **kwargs}
        if self.is_reference():
            assert ref_loss is not None, "Reference required for this preference optimization"
            assert self.loss_ref_fn is not None, "No reference loss function"
            loss, metrics = self.loss_ref_fn(loss, ref_loss, **call_args)
        else:
            assert self.loss_fn is not None, "No loss function"
            loss, metrics = self.loss_fn(loss, **call_args)

        return loss, metrics
|
||||
|
||||
|
||||
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta: float):
    """
    Diffusion DPO loss

    Args:
        loss: training-model losses, preferred (w) samples stacked before
            dispreferred (l) samples along dim 0; reduced over dims (1, 2, 3),
            so a 4-D tensor is expected
        ref_loss: reference-model losses with the same layout as ``loss``
        beta: DPO KL-penalty weight

    Returns:
        tuple: (loss, metrics) where loss has one value per (w, l) pair and
        metrics is a dict of detached Python floats
    """
    loss_w, loss_l = loss.chunk(2)
    raw_loss = 0.5 * (loss_w + loss_l)
    model_diff = loss_w - loss_l

    ref_losses_w, ref_losses_l = ref_loss.chunk(2)
    ref_diff = ref_losses_w - ref_losses_l

    scale_term = -0.5 * beta
    inside_term = scale_term * (model_diff - ref_diff)
    loss = -1 * torch.nn.functional.logsigmoid(inside_term).mean(dim=(1, 2, 3))

    # Implicit accuracy is a per-pair statistic. Reduce each pair to one value
    # first; counting every element and dividing by the batch size produced
    # values larger than 1.
    pair_term = inside_term.detach().mean(dim=(1, 2, 3))
    implicit_acc = (pair_term > 0).float().mean()
    implicit_acc += 0.5 * (pair_term == 0).float().mean()  # ties count as half

    metrics = {
        "loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
        "loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
        "loss/diffusion_dpo_ref_loss": ref_loss.detach().mean().item(),
        "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(),
    }

    return loss, metrics
|
||||
|
||||
|
||||
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
def mapo_loss(model_losses: torch.Tensor, beta: float, total_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
MaPO loss
|
||||
|
||||
Paper: Margin-aware Preference Optimization for Aligning Diffusion Models without Reference
|
||||
https://mapo-t2i.github.io/
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2
|
||||
loss: pairs of w, l losses B//2, C, H, W. We want full distribution of the
|
||||
loss for numerical stability
|
||||
mapo_weight: mapo weight
|
||||
num_train_timesteps: number of timesteps
|
||||
total_timesteps: number of timesteps
|
||||
"""
|
||||
loss_w, loss_l = model_losses.chunk(2)
|
||||
|
||||
snr = 0.5
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (snr * loss_l) / (torch.exp(snr * loss_l) - 1)
|
||||
phi_coefficient = 0.5
|
||||
win_score = (phi_coefficient * loss_w) / (torch.exp(phi_coefficient * loss_w) - 1)
|
||||
lose_score = (phi_coefficient * loss_l) / (torch.exp(phi_coefficient * loss_l) - 1)
|
||||
|
||||
# Ratio loss.
|
||||
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
|
||||
ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
|
||||
ratio_losses = mapo_weight * ratio
|
||||
# Score difference loss
|
||||
score_difference = win_score - lose_score
|
||||
|
||||
# Margin loss.
|
||||
# By multiplying T in the inner term , we try to maximize the
|
||||
# margin throughout the overall denoising process.
|
||||
# T here is the number of training steps from the
|
||||
# underlying noise scheduler.
|
||||
margin = F.logsigmoid(score_difference * total_timesteps + 1e-10)
|
||||
margin_losses = beta * margin
|
||||
|
||||
# Full MaPO loss
|
||||
loss = loss_w - ratio_losses
|
||||
loss = loss_w.mean(dim=(1, 2, 3)) - margin_losses.mean(dim=(1, 2, 3))
|
||||
|
||||
metrics = {
|
||||
"loss/mapo_total": loss.detach().mean().item(),
|
||||
"loss/mapo_ratio": -ratio_losses.detach().mean().item(),
|
||||
"loss/mapo_ratio": -margin_losses.detach().mean().item(),
|
||||
"loss/mapo_w_loss": loss_w.detach().mean().item(),
|
||||
"loss/mapo_l_loss": loss_l.detach().mean().item(),
|
||||
"loss/mapo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
|
||||
"loss/mapo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
|
||||
"loss/mapo_score_difference": score_difference.detach().mean().item(),
|
||||
"loss/mapo_win_score": win_score.detach().mean().item(),
|
||||
"loss/mapo_lose_score": lose_score.detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
def ddo_loss(loss, ref_loss, w_t: float, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
    """
    Implements Direct Discriminative Optimization (DDO) loss.

    DDO bridges likelihood-based generative training with GAN objectives
    by parameterizing a discriminator using the likelihood ratio between
    a learnable target model and a fixed reference model.

    Args:
        loss: Target model loss; summed over dims (1, 2, 3), so a 4-D tensor is expected
        ref_loss: Reference model loss, same shape as ``loss`` (detached here)
        w_t: weight at timestep
        ddo_alpha: Weight coefficient for the fake samples loss term.
            Controls the balance between real/fake samples in training.
            Higher values increase penalty on reference model samples.
        ddo_beta: Scaling factor for the likelihood ratio to control gradient magnitude.
            Smaller values produce a smoother optimization landscape.
            Too large values can lead to numerical instability.

    Returns:
        tuple: (total_loss, metrics_dict)
            - total_loss: Combined DDO loss, one value per sample
            - metrics_dict: Dictionary containing component losses for monitoring
    """
    # The reference model is fixed: never let gradients flow into it.
    ref_loss = ref_loss.detach()

    # Log likelihood from weighted loss
    target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3))
    ref_logp = -torch.sum(w_t * ref_loss, dim=(1, 2, 3))

    # ∆xt,t,ε = -w(t) * [||εθ(xt,t) - ε||²₂ - ||εθref(xt,t) - ε||²₂]
    delta = target_logp - ref_logp

    # log_ratio = β * log pθ(x)/pθref(x)
    log_ratio = ddo_beta * delta

    # -log σ(log_ratio): data term
    data_loss = -F.logsigmoid(log_ratio)

    # -α log(1 - σ(log_ratio)) = -α log σ(-log_ratio): reference term
    ref_loss_term = -ddo_alpha * F.logsigmoid(-log_ratio)

    total_loss = data_loss + ref_loss_term

    metrics = {
        "loss/ddo_data": data_loss.detach().mean().item(),
        "loss/ddo_ref": ref_loss_term.detach().mean().item(),
        "loss/ddo_total": total_loss.detach().mean().item(),
        "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio.detach()).mean().item(),
    }

    return total_loss, metrics
|
||||
|
||||
|
||||
def cpo_loss(loss: torch.Tensor, beta: float = 0.1) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
CPO Loss = L(π_θ; U) - E[log π_θ(y_w|x)]
|
||||
|
||||
Where L(π_θ; U) is the uniform reference DPO loss and the second term
|
||||
is a behavioral cloning regularizer on preferred data.
|
||||
|
||||
Args:
|
||||
loss: Losses of w and l B, C, H, W
|
||||
beta: Weight for log ratio (Similar to Diffusion DPO)
|
||||
"""
|
||||
# L(π_θ; U) - DPO loss with uniform reference (no reference model needed)
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
|
||||
# Prevent values from being too small, causing large gradients
|
||||
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
|
||||
uniform_dpo_loss = -F.logsigmoid(beta * log_ratio).mean()
|
||||
|
||||
# Behavioral cloning regularizer: -E[log π_θ(y_w|x)]
|
||||
bc_regularizer = -loss_w.mean()
|
||||
|
||||
# Total CPO loss
|
||||
cpo_loss = uniform_dpo_loss + bc_regularizer
|
||||
|
||||
metrics = {}
|
||||
metrics["loss/cpo_reward_margin"] = uniform_dpo_loss.detach().mean().item()
|
||||
|
||||
return cpo_loss, metrics
|
||||
|
||||
|
||||
def bpo_loss(loss: Tensor, ref_loss: Tensor, beta: float, lambda_: float) -> tuple[Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
Bregman Preference Optimization
|
||||
|
||||
Paper: Preference Optimization by Estimating the
|
||||
Ratio of the Data Distribution
|
||||
|
||||
Computes the BPO loss
|
||||
loss: Loss from the training model B
|
||||
ref_loss: Loss from the reference model B
|
||||
param beta : Regularization coefficient
|
||||
param lambda : hyperparameter for SBA
|
||||
"""
|
||||
# Compute the model ratio corresponding to Line 4 of Algorithm 1.
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
ref_loss_w, ref_loss_l = ref_loss.chunk(2)
|
||||
|
||||
logits = loss_w - loss_l - ref_loss_w + ref_loss_l
|
||||
reward_margin = beta * logits
|
||||
R = torch.exp(-reward_margin)
|
||||
|
||||
# Clip R values to be no smaller than 0.01 for training stability
|
||||
R = torch.max(R, torch.full_like(R, 0.01))
|
||||
|
||||
# Compute the loss according to the function h , following Line 5 of Algorithm 1.
|
||||
if lambda_ == 0.0:
|
||||
losses = R + torch.log(R)
|
||||
else:
|
||||
losses = R ** (lambda_ + 1) - ((lambda_ + 1) / lambda_) * (R ** (-lambda_))
|
||||
losses /= 4 * (1 + lambda_)
|
||||
|
||||
metrics = {}
|
||||
metrics["loss/bpo_reward_margin"] = reward_margin.detach().mean().item()
|
||||
metrics["loss/bpo_R"] = R.detach().mean().item()
|
||||
return losses.mean(dim=(1, 2, 3)), metrics
|
||||
|
||||
|
||||
def kto_loss(loss: Tensor, ref_loss: Tensor, kl_loss: Tensor, ref_kl_loss: Tensor, w_t=1.0, undesireable_w_t=1.0, beta=0.1):
    """
    KTO: Model Alignment as Prospect Theoretic Optimization
    https://arxiv.org/abs/2402.01306

    Compute the Kahneman-Tversky loss for a batch of policy and reference model losses.
    If generation y ~ p_desirable, we have the 'desirable' loss:
        L(x, y) := 1 - sigmoid(beta * ([log p_policy(y|x) - log p_reference(y|x)] - KL(p_policy || p_reference)))
    If generation y ~ p_undesirable, we have the 'undesirable' loss:
        L(x, y) := 1 - sigmoid(beta * (KL(p_policy || p_reference) - [log p_policy(y|x) - log p_reference(y|x)]))

    The desirable losses are weighed by w_t.
    The undesirable losses are weighed by undesireable_w_t.
    This should be used to address imbalances in the ratio of desirable:undesirable examples respectively.

    The KL term is estimated from ``kl_loss`` / ``ref_kl_loss`` (losses on
    unrelated outputs): the mean policy-vs-reference reward, clamped to be
    non-negative.

    Args:
        loss: policy losses, desirable (w) stacked before undesirable (l) along dim 0
        ref_loss: reference losses with the same layout as ``loss``
        kl_loss: policy losses on the mismatched samples used for the KL estimate
        ref_kl_loss: reference losses on the same mismatched samples
        w_t: weight of the desirable losses
        undesireable_w_t: weight of the undesirable losses
        beta: scaling of the reward / KL difference

    Returns:
        Scalar tensor: mean of the weighted desirable and undesirable losses.
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Convert losses to rewards (negative loss = positive reward). Each reward
    # compares the policy with the reference on the SAME sample, matching
    # log p_policy(y|x) - log p_reference(y|x) in the formulas above.
    chosen_rewards = -(loss_w - ref_loss_w)
    rejected_rewards = -(loss_l - ref_loss_l)
    KL_rewards = -(kl_loss - ref_kl_loss)

    # Estimate KL divergence using unmatched samples
    KL_estimate = KL_rewards.mean().clamp(min=0)

    losses = []

    # Desirable (chosen) samples: we want reward > KL
    if chosen_rewards.shape[0] > 0:
        chosen_kto_losses = w_t * (1 - torch.sigmoid(beta * (chosen_rewards - KL_estimate)))
        losses.append(chosen_kto_losses)

    # Undesirable (rejected) samples: we want KL > reward
    if rejected_rewards.shape[0] > 0:
        rejected_kto_losses = undesireable_w_t * (1 - torch.sigmoid(beta * (KL_estimate - rejected_rewards)))
        losses.append(rejected_kto_losses)

    if losses:
        total_loss = torch.cat(losses, 0).mean()
    else:
        # Same device/dtype as the inputs, unlike a bare torch.tensor(0.0).
        total_loss = loss.new_zeros(())

    return total_loss
|
||||
|
||||
|
||||
def ipo_loss(loss: Tensor, ref_loss: Tensor, tau=0.1):
    """
    IPO: Iterative Preference Optimization for Text-to-Video Generation
    https://arxiv.org/abs/2502.02088

    Squared-error objective that pulls the policy-vs-reference reward margin
    between the preferred and dispreferred sample towards 1 / (2 * tau).

    Inputs are losses (lower is better), so each reward is
    ``ref_loss - loss``: positive when the policy beats the reference.

    Args:
        loss: policy losses, preferred (w) stacked before dispreferred (l) along dim 0
        ref_loss: reference losses with the same layout as ``loss``
        tau: regularisation strength; the target margin is 1 / (2 * tau)

    Returns:
        tuple: (unreduced squared-error losses, metrics dict)
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # log π(y) - log π_ref(y) ≈ ref_loss - loss
    chosen_rewards = ref_loss_w - loss_w
    rejected_rewards = ref_loss_l - loss_l

    losses = (chosen_rewards - rejected_rewards - (1 / (2 * tau))).pow(2)

    metrics: dict[str, int | float] = {}
    metrics["loss/ipo_chosen_rewards"] = chosen_rewards.detach().mean().item()
    metrics["loss/ipo_rejected_rewards"] = rejected_rewards.detach().mean().item()

    return losses, metrics
|
||||
|
||||
|
||||
def compute_importance_weight(loss: Tensor, ref_loss: Tensor) -> Tensor:
    """
    Approximate the importance weight w(t) = p_θ(x_{t-1}|x_t) / q(x_{t-1}|x_t, x_0)

    The weight is exp(ref_loss - loss), computed element-wise; it is larger
    than 1 wherever the training model's loss is below the reference loss.

    Args:
        loss: Training model loss B, ...
        ref_loss: Reference model loss B, ...

    Returns:
        Tensor of importance weights with the broadcast shape of the inputs.
    """
    loss_gap = ref_loss - loss
    return loss_gap.exp()
|
||||
|
||||
|
||||
def clip_importance_weight(w_t: Tensor, epsilon=0.1) -> Tensor:
    """
    Clip importance weights into [1 - ε, 1 + ε]: w̃(t) = clip(w(t), 1-ε, 1+ε)

    Args:
        w_t: importance weights
        epsilon: half-width ε of the allowed band around 1

    Returns:
        Tensor of the same shape with every value limited to the band.
    """
    lower_bound = 1 - epsilon
    upper_bound = 1 + epsilon
    return w_t.clamp(min=lower_bound, max=upper_bound)
|
||||
|
||||
|
||||
def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[Tensor, dict[str, int | float]]:
    """
    SDPO Loss (Formula 11):
    L_SDPO(θ) = -E[log σ(w̃_θ(t) · ψ(x^w_{t-1}|x^w_t) - w̃_θ(t) · ψ(x^l_{t-1}|x^l_t))]

    where ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))

    Args:
        loss: training-model losses, preferred (w) stacked before dispreferred (l)
            along dim 0; reduced over dims (1, 2, 3), so a 4-D tensor is expected
        ref_loss: reference-model losses with the same layout as ``loss``
        beta: scale β of the log-ratio ψ
        epsilon: clipping half-width ε for the inverse importance weights

    Returns:
        tuple: (loss per (w, l) pair, metrics dict)
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Compute step-wise importance weights for inverse weighting
    w_theta_w = compute_importance_weight(loss_w, ref_loss_w)
    w_theta_l = compute_importance_weight(loss_l, ref_loss_l)

    # Inverse weighting with clipping (Formula 12)
    w_theta_w_inv = clip_importance_weight(1.0 / (w_theta_w + 1e-8), epsilon=epsilon)
    w_theta_l_inv = clip_importance_weight(1.0 / (w_theta_l + 1e-8), epsilon=epsilon)
    w_theta_max = torch.max(w_theta_w_inv, w_theta_l_inv)

    # Compute ψ terms: ψ(x_{t-1}|x_t) = β · log(p*_θ(x_{t-1}|x_t) / p_ref(x_{t-1}|x_t))
    # Approximated using negative MSE differences

    # For preferred samples
    log_ratio_w = -loss_w + ref_loss_w
    psi_w = beta * log_ratio_w

    # For dispreferred samples
    log_ratio_l = -loss_l + ref_loss_l
    psi_l = beta * log_ratio_l

    # Final SDPO loss computation. (A debug print of the mean logit used to run
    # here on every step; the value is derivable from the metrics below.)
    logits = w_theta_max * psi_w - w_theta_max * psi_l
    # logsigmoid avoids log(0) = -inf when sigmoid underflows for very negative logits.
    sigmoid_loss = -F.logsigmoid(logits)

    metrics: dict[str, int | float] = {}
    metrics["loss/sdpo_log_ratio_w"] = log_ratio_w.detach().mean().item()
    metrics["loss/sdpo_log_ratio_l"] = log_ratio_l.detach().mean().item()
    metrics["loss/sdpo_w_theta_max"] = w_theta_max.detach().mean().item()
    metrics["loss/sdpo_w_theta_w"] = w_theta_w.detach().mean().item()
    metrics["loss/sdpo_w_theta_l"] = w_theta_l.detach().mean().item()

    return sigmoid_loss.mean(dim=(1, 2, 3)), metrics
|
||||
|
||||
|
||||
def simpo_loss(
|
||||
loss: torch.Tensor, loss_type: str = "sigmoid", gamma_beta_ratio: float = 0.25, beta: float = 2.0, smoothing: float = 0.0
|
||||
) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
Compute the SimPO loss for a batch of policy and reference model
|
||||
|
||||
SimPO: Simple Preference Optimization with a Reference-Free Reward
|
||||
https://arxiv.org/abs/2405.14734
|
||||
"""
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
|
||||
pi_logratios = loss_w - loss_l
|
||||
pi_logratios = pi_logratios
|
||||
logits = pi_logratios - gamma_beta_ratio
|
||||
|
||||
if loss_type == "sigmoid":
|
||||
losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
|
||||
elif loss_type == "hinge":
|
||||
losses = torch.relu(1 - beta * logits)
|
||||
else:
|
||||
raise ValueError(f"Unknown loss type: {loss_type}. Should be one of ['sigmoid', 'hinge']")
|
||||
|
||||
metrics = {}
|
||||
metrics["loss/simpo_chosen_rewards"] = (beta * loss_w.detach()).mean().item()
|
||||
metrics["loss/simpo_rejected_rewards"] = (beta * loss_l.detach()).mean().item()
|
||||
metrics["loss/simpo_logratio"] = (beta * logits.detach()).mean().item()
|
||||
|
||||
return losses, metrics
|
||||
|
||||
|
||||
def normalize_gradients(model):
    """Scale all parameter gradients in place so their global L2 norm is 1.

    Parameters without a gradient are ignored. Nothing happens when no
    parameter has a gradient or when the global norm is zero.

    Args:
        model: object exposing ``parameters()`` that yields tensors with ``.grad``
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        # torch.stack raises on an empty list; there is nothing to normalise.
        return

    total_norm = torch.norm(torch.stack([torch.norm(g.detach()) for g in grads]))
    if total_norm > 0:
        for g in grads:
            g.div_(total_norm)
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjecti
|
||||
|
||||
# TODO remove circular import by moving ImageInfo to a separate file
|
||||
# from library.train_util import ImageInfo
|
||||
|
||||
# from library.train_util import ImageSetInfo
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -514,6 +514,7 @@ class LatentsCachingStrategy:
|
||||
info.latents_flipped = flipped_latent
|
||||
info.alpha_mask = alpha_mask
|
||||
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
|
||||
@@ -209,19 +209,71 @@ class ImageInfo:
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
self.resize_interpolation: Optional[str] = None
|
||||
|
||||
self._current = 0
|
||||
|
||||
class ImageSetInfo(ImageInfo):
|
||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||
super().__init__(image_key, num_repeats, caption, is_reg, absolute_path)
|
||||
def __iter__(self):
    """Iterate over this single image: yields ``self`` exactly once.

    A fresh iterator is returned on every call, so the object can be looped
    over repeatedly (and in nested loops) without any shared cursor state.
    """
    return iter((self,))
|
||||
|
||||
self.absolute_paths = [absolute_path]
|
||||
self.captions = [caption]
|
||||
self.image_sizes = []
|
||||
def __next__(self):
    """Return ``self`` once, then raise StopIteration and rewind.

    The cursor lives in ``self._current``. It is reset when iteration ends so
    the next loop over the same object yields ``self`` again; resetting a
    differently named attribute left the cursor exhausted forever.
    """
    if self._current < 1:
        self._current += 1
        return self
    else:
        self._current = 0
        raise StopIteration
|
||||
|
||||
def add(self, absolute_path, caption, size):
|
||||
self.absolute_paths.append(absolute_path)
|
||||
self.captions.append(caption)
|
||||
self.image_sizes.append(size)
|
||||
def __len__(self):
    """Return 1: this object always reports a length of one."""
    return 1
|
||||
|
||||
def __getitem__(self, item):
    """Return ``self`` for any index; ``item`` is ignored."""
    return self
|
||||
|
||||
@staticmethod
def _pin_tensor(tensor):
    """Return ``tensor.pin_memory()``; ``None`` is passed straight through."""
    if tensor is None:
        return tensor
    return tensor.pin_memory()
|
||||
|
||||
def pin_memory(self):
    """Replace each cached tensor attribute with its pinned version.

    Every listed attribute is passed through ``self._pin_tensor`` in order and
    assigned back. Returns ``self`` so the call can be chained.
    """
    pinned_fields = (
        "latents",
        "latents_flipped",
        "text_encoder_outputs1",
        "text_encoder_outputs2",
        "text_encoder_pool2",
        "alpha_mask",
    )
    for field_name in pinned_fields:
        setattr(self, field_name, self._pin_tensor(getattr(self, field_name)))
    return self
|
||||
|
||||
|
||||
class ImageSetInfo:
    """An ordered group of images that is handled as one dataset entry.

    The set exposes ``image_key`` and ``bucket_reso`` of its first image and
    supports ``len()``, indexing and iteration over the contained images.
    """

    def __init__(self, images: "list[ImageInfo] | None" = None) -> None:
        """
        Args:
            images: images in this set. Defaults to a new empty list per
                instance (a shared ``[]`` default would leak images between sets).
        """
        super().__init__()

        self.images = [] if images is None else images
        # Cursor used only by direct ``next(obj)`` calls; ``for`` loops do not touch it.
        self.current = 0

    @property
    def image_key(self):
        """``image_key`` of the first image in the set."""
        return self.images[0].image_key

    @property
    def bucket_reso(self):
        """``bucket_reso`` of the first image in the set."""
        return self.images[0].bucket_reso

    def __iter__(self):
        # A fresh iterator per loop: an early ``break`` or a nested loop over
        # the same set can no longer leave a shared cursor mid-way.
        return iter(self.images)

    def __next__(self):
        """Step the instance-level cursor; kept for callers using ``next(obj)``."""
        if self.current < len(self.images):
            result = self.images[self.current]
            self.current += 1
            return result
        else:
            self.current = 0
            raise StopIteration

    def __getitem__(self, item):
        return self.images[item]

    def __len__(self):
        return len(self.images)
|
||||
|
||||
|
||||
class BucketManager:
|
||||
@@ -727,7 +779,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
resize_interpolation: Optional[str] = None
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -763,10 +815,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
if resize_interpolation is not None:
|
||||
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
|
||||
assert validate_interpolation_fn(
|
||||
resize_interpolation
|
||||
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
self.image_data: Dict[str, ImageInfo] = {}
|
||||
self.image_data: Dict[str, ImageInfo | ImageSetInfo] = {}
|
||||
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
||||
|
||||
self.replacements = {}
|
||||
@@ -1019,7 +1073,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
input_ids = torch.stack(iids_list) # 3,77
|
||||
return input_ids
|
||||
|
||||
def register_image(self, info: ImageInfo, subset: BaseSubset):
|
||||
def register_image(self, info: ImageInfo | ImageSetInfo, subset: BaseSubset):
|
||||
self.image_data[info.image_key] = info
|
||||
self.image_to_subset[info.image_key] = subset
|
||||
|
||||
@@ -1029,9 +1083,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
min_size and max_size are ignored when enable_bucket is False
|
||||
"""
|
||||
logger.info("loading image sizes.")
|
||||
for info in tqdm(self.image_data.values()):
|
||||
if info.image_size is None:
|
||||
info.image_size = self.get_image_size(info.absolute_path)
|
||||
for infos in tqdm(self.image_data.values()):
|
||||
for info in infos:
|
||||
if info.image_size is None:
|
||||
info.image_size = self.get_image_size(info.absolute_path)
|
||||
|
||||
# # run in parallel
|
||||
# max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes)
|
||||
@@ -1073,26 +1128,37 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
img_ar_errors = []
|
||||
for image_info in self.image_data.values():
|
||||
image_width, image_height = image_info.image_size
|
||||
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(
|
||||
image_width, image_height
|
||||
)
|
||||
for image_infos in self.image_data.values():
|
||||
for image_info in image_infos:
|
||||
image_width, image_height = image_info.image_size
|
||||
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(
|
||||
image_width, image_height
|
||||
)
|
||||
|
||||
# logger.info(image_info.image_key, image_info.bucket_reso)
|
||||
img_ar_errors.append(abs(ar_error))
|
||||
# logger.info(image_info.image_key, image_info.bucket_reso)
|
||||
img_ar_errors.append(abs(ar_error))
|
||||
|
||||
self.bucket_manager.sort()
|
||||
else:
|
||||
self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
|
||||
self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
|
||||
for image_info in self.image_data.values():
|
||||
image_width, image_height = image_info.image_size
|
||||
image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
|
||||
for image_infos in self.image_data.values():
|
||||
for info in image_infos:
|
||||
image_width, image_height = info.image_size
|
||||
info.bucket_reso, info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
|
||||
|
||||
for image_info in self.image_data.values():
|
||||
for _ in range(image_info.num_repeats):
|
||||
self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
|
||||
for infos in self.image_data.values():
|
||||
bucket_reso = None
|
||||
for info in infos:
|
||||
if bucket_reso is None:
|
||||
bucket_reso = info.bucket_reso
|
||||
else:
|
||||
assert (
|
||||
bucket_reso == info.bucket_reso
|
||||
), f"Image pair not found in same bucket. {info.image_key} {bucket_reso} {info.bucket_reso}"
|
||||
|
||||
for _ in range(infos[0].num_repeats):
|
||||
self.bucket_manager.add_image(infos.bucket_reso, infos.image_key)
|
||||
|
||||
# bucket情報を表示、格納する
|
||||
if self.enable_bucket:
|
||||
@@ -1176,7 +1242,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
and self.random_crop == other.random_crop
|
||||
)
|
||||
|
||||
batch: List[ImageInfo] = []
|
||||
batch: list[ImageInfo] = []
|
||||
current_condition = None
|
||||
|
||||
# support multiple-gpus
|
||||
@@ -1184,7 +1250,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
process_index = accelerator.process_index
|
||||
|
||||
# define a function to submit a batch to cache
|
||||
def submit_batch(batch, cond):
|
||||
def submit_batch(batch: list[ImageInfo], cond):
|
||||
for info in batch:
|
||||
if info.image is not None and isinstance(info.image, Future):
|
||||
info.image = info.image.result() # future to image
|
||||
@@ -1203,52 +1269,52 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
try:
|
||||
# iterate images
|
||||
logger.info("caching latents...")
|
||||
for i, info in enumerate(tqdm(image_infos)):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
for i, infos in enumerate(tqdm(image_infos)):
|
||||
subset = self.image_to_subset[infos[0].image_key]
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
|
||||
|
||||
# if the modulo of num_processes is not equal to process_index, skip caching
|
||||
# this makes each process cache different latents
|
||||
if i % num_processes != process_index:
|
||||
for info in infos:
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
# print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
|
||||
|
||||
cache_available = caching_strategy.is_disk_cached_latents_expected(
|
||||
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
|
||||
)
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
# if the modulo of num_processes is not equal to process_index, skip caching
|
||||
# this makes each process cache different latents
|
||||
if i % num_processes != process_index:
|
||||
continue
|
||||
|
||||
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
|
||||
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
# print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
|
||||
|
||||
if info.image is None:
|
||||
# load image in parallel
|
||||
info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
|
||||
cache_available = caching_strategy.is_disk_cached_latents_expected(
|
||||
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
|
||||
)
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
batch.append(info)
|
||||
current_condition = condition
|
||||
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
|
||||
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= caching_strategy.batch_size:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
current_condition = None
|
||||
if info.image is None:
|
||||
# load image in parallel
|
||||
info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask)
|
||||
|
||||
batch.append(info)
|
||||
current_condition = condition
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= caching_strategy.batch_size:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
current_condition = None
|
||||
|
||||
if len(batch) > 0:
|
||||
submit_batch(batch, current_condition)
|
||||
|
||||
finally:
|
||||
executor.shutdown()
|
||||
|
||||
@@ -1277,44 +1343,44 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
and self.random_crop == other.random_crop
|
||||
)
|
||||
|
||||
batches: List[Tuple[Condition, List[ImageInfo]]] = []
|
||||
batch: List[ImageInfo] = []
|
||||
batches: list[tuple[Condition, list[ImageInfo | ImageSetInfo]]] = []
|
||||
batch: list[ImageInfo | ImageSetInfo] = []
|
||||
current_condition = None
|
||||
|
||||
logger.info("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if cache_to_disk:
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
if not is_main_process: # store to info only
|
||||
for infos in tqdm(image_infos):
|
||||
subset = self.image_to_subset[infos[0].image_key]
|
||||
for info in infos:
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
cache_available = is_disk_cached_latents_is_expected(
|
||||
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
|
||||
)
|
||||
# check disk cache exists and size of latents
|
||||
if cache_to_disk:
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
cache_available = is_disk_cached_latents_is_expected(
|
||||
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
|
||||
)
|
||||
|
||||
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
|
||||
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
batch.append(info)
|
||||
current_condition = condition
|
||||
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
|
||||
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= vae_batch_size:
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
current_condition = None
|
||||
batch.append(info)
|
||||
current_condition = condition
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= vae_batch_size:
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
current_condition = None
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append((current_condition, batch))
|
||||
@@ -1348,27 +1414,28 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
process_index = accelerator.process_index
|
||||
|
||||
logger.info("checking cache validity...")
|
||||
for i, info in enumerate(tqdm(image_infos)):
|
||||
# check disk cache exists and size of text encoder outputs
|
||||
if caching_strategy.cache_to_disk:
|
||||
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
|
||||
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
|
||||
for i, infos in enumerate(tqdm(image_infos)):
|
||||
for info in infos:
|
||||
# check disk cache exists and size of text encoder outputs
|
||||
if caching_strategy.cache_to_disk:
|
||||
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
|
||||
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
|
||||
|
||||
# if the modulo of num_processes is not equal to process_index, skip caching
|
||||
# this makes each process cache different text encoder outputs
|
||||
if i % num_processes != process_index:
|
||||
continue
|
||||
# if the modulo of num_processes is not equal to process_index, skip caching
|
||||
# this makes each process cache different text encoder outputs
|
||||
if i % num_processes != process_index:
|
||||
continue
|
||||
|
||||
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
batch.append(info)
|
||||
batch.append(info)
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
@@ -1526,9 +1593,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def load_and_transform_image(self, subset, image_info, absolute_path, flipped):
|
||||
# 画像を読み込み、必要ならcropする
|
||||
|
||||
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
|
||||
subset, absolute_path, subset.alpha_mask
|
||||
)
|
||||
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, absolute_path, subset.alpha_mask)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
@@ -1550,9 +1615,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
img = img[:, p : p + self.width]
|
||||
|
||||
im_h, im_w = img.shape[0:2]
|
||||
assert (
|
||||
im_h == self.height and im_w == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {absolute_path}"
|
||||
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {absolute_path}"
|
||||
|
||||
original_size = [im_w, im_h]
|
||||
crop_ltrb = (0, 0, 0, 0)
|
||||
@@ -1679,87 +1742,69 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
custom_attributes = []
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
image_infos = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
for image_info in image_infos:
|
||||
custom_attributes.append(subset.custom_attributes)
|
||||
|
||||
custom_attributes.append(subset.custom_attributes)
|
||||
# in case of fine tuning, is_reg is always False
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
|
||||
# in case of fine tuning, is_reg is always False
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||
|
||||
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
original_size = image_info.latents_original_size
|
||||
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
|
||||
if not flipped:
|
||||
latents = image_info.latents
|
||||
alpha_mask = image_info.alpha_mask
|
||||
else:
|
||||
latents = image_info.latents_flipped
|
||||
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
|
||||
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
original_size = image_info.latents_original_size
|
||||
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
|
||||
if not flipped:
|
||||
latents = image_info.latents
|
||||
alpha_mask = image_info.alpha_mask
|
||||
else:
|
||||
latents = image_info.latents_flipped
|
||||
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
|
||||
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
image = None
|
||||
|
||||
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
image = None
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
|
||||
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
|
||||
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
||||
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
|
||||
)
|
||||
if flipped:
|
||||
latents = flipped_latents
|
||||
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
|
||||
del flipped_latents
|
||||
latents = torch.FloatTensor(latents)
|
||||
if alpha_mask is not None:
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
|
||||
image = None
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
|
||||
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
|
||||
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
|
||||
else:
|
||||
if isinstance(image_info, ImageSetInfo):
|
||||
for absolute_path in image_info.absolute_paths:
|
||||
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, absolute_path, flipped)
|
||||
images.append(image)
|
||||
latents_list.append(None)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
|
||||
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
|
||||
if not flipped:
|
||||
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
|
||||
else:
|
||||
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
|
||||
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
|
||||
|
||||
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
|
||||
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
|
||||
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
|
||||
flippeds.append(flipped)
|
||||
if self.enable_bucket:
|
||||
img, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
|
||||
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
|
||||
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
||||
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
|
||||
)
|
||||
if flipped:
|
||||
latents = flipped_latents
|
||||
alpha_mask = (
|
||||
None if alpha_mask is None else alpha_mask[:, ::-1].copy()
|
||||
) # copy to avoid negative stride problem
|
||||
del flipped_latents
|
||||
latents = torch.FloatTensor(latents)
|
||||
if alpha_mask is not None:
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
|
||||
image = None
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
|
||||
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
|
||||
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
|
||||
else:
|
||||
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped)
|
||||
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(
|
||||
subset, image_info, image_info.absolute_path, flipped
|
||||
)
|
||||
images.append(image)
|
||||
latents_list.append(None)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
|
||||
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
target_size = (
|
||||
(image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
)
|
||||
|
||||
if not flipped:
|
||||
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
|
||||
@@ -1772,59 +1817,58 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
|
||||
flippeds.append(flipped)
|
||||
|
||||
# captionとtext encoder outputを処理する
|
||||
caption = image_info.caption # default
|
||||
|
||||
# captionとtext encoder outputを処理する
|
||||
caption = image_info.caption # default
|
||||
|
||||
tokenization_required = (
|
||||
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
|
||||
)
|
||||
text_encoder_outputs = None
|
||||
input_ids = None
|
||||
|
||||
if image_info.text_encoder_outputs is not None:
|
||||
# cached
|
||||
text_encoder_outputs = image_info.text_encoder_outputs
|
||||
elif image_info.text_encoder_outputs_npz is not None:
|
||||
# on disk
|
||||
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
|
||||
image_info.text_encoder_outputs_npz
|
||||
tokenization_required = (
|
||||
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
|
||||
)
|
||||
else:
|
||||
tokenization_required = True
|
||||
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||
text_encoder_outputs = None
|
||||
input_ids = None
|
||||
|
||||
if tokenization_required:
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
|
||||
# if self.XTI_layers:
|
||||
# caption_layer = []
|
||||
# for layer in self.XTI_layers:
|
||||
# token_strings_from = " ".join(self.token_strings)
|
||||
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
# caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
# caption_layer.append(caption_)
|
||||
# captions.append(caption_layer)
|
||||
# else:
|
||||
# captions.append(caption)
|
||||
if image_info.text_encoder_outputs is not None:
|
||||
# cached
|
||||
text_encoder_outputs = image_info.text_encoder_outputs
|
||||
elif image_info.text_encoder_outputs_npz is not None:
|
||||
# on disk
|
||||
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
|
||||
image_info.text_encoder_outputs_npz
|
||||
)
|
||||
else:
|
||||
tokenization_required = True
|
||||
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||
|
||||
# if not self.token_padding_disabled: # this option might be omitted in future
|
||||
# # TODO get_input_ids must support SD3
|
||||
# if self.XTI_layers:
|
||||
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
# else:
|
||||
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
# input_ids_list.append(token_caption)
|
||||
if tokenization_required:
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
|
||||
# if self.XTI_layers:
|
||||
# caption_layer = []
|
||||
# for layer in self.XTI_layers:
|
||||
# token_strings_from = " ".join(self.token_strings)
|
||||
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
# caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
# caption_layer.append(caption_)
|
||||
# captions.append(caption_layer)
|
||||
# else:
|
||||
# captions.append(caption)
|
||||
|
||||
# if len(self.tokenizers) > 1:
|
||||
# if self.XTI_layers:
|
||||
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
# else:
|
||||
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
# input_ids2_list.append(token_caption2)
|
||||
# if not self.token_padding_disabled: # this option might be omitted in future
|
||||
# # TODO get_input_ids must support SD3
|
||||
# if self.XTI_layers:
|
||||
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
# else:
|
||||
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
# input_ids_list.append(token_caption)
|
||||
|
||||
input_ids_list.append(input_ids)
|
||||
captions.append(caption)
|
||||
# if len(self.tokenizers) > 1:
|
||||
# if self.XTI_layers:
|
||||
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
# else:
|
||||
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
# input_ids2_list.append(token_caption2)
|
||||
|
||||
input_ids_list.append(input_ids)
|
||||
captions.append(caption)
|
||||
|
||||
def none_or_stack_elements(tensors_list, converter):
|
||||
# [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
|
||||
@@ -1864,6 +1908,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
example["images"] = images
|
||||
|
||||
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||
|
||||
example["captions"] = captions
|
||||
|
||||
example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw])
|
||||
@@ -1890,41 +1935,42 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
random_crop = None
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
image_infos = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
|
||||
if flip_aug is None:
|
||||
flip_aug = subset.flip_aug
|
||||
alpha_mask = subset.alpha_mask
|
||||
random_crop = subset.random_crop
|
||||
bucket_reso = image_info.bucket_reso
|
||||
else:
|
||||
# TODO そもそも混在してても動くようにしたほうがいい
|
||||
assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
|
||||
assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch"
|
||||
assert random_crop == subset.random_crop, "random_crop must be same in a batch"
|
||||
assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
|
||||
for image_info in image_infos:
|
||||
if flip_aug is None:
|
||||
flip_aug = subset.flip_aug
|
||||
alpha_mask = subset.alpha_mask
|
||||
random_crop = subset.random_crop
|
||||
bucket_reso = image_info.bucket_reso
|
||||
else:
|
||||
# TODO そもそも混在してても動くようにしたほうがいい
|
||||
assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
|
||||
assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch"
|
||||
assert random_crop == subset.random_crop, "random_crop must be same in a batch"
|
||||
assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
|
||||
|
||||
caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc.
|
||||
caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc.
|
||||
|
||||
if self.caching_mode == "latents":
|
||||
image = load_image(image_info.absolute_path)
|
||||
else:
|
||||
image = None
|
||||
if self.caching_mode == "latents":
|
||||
image = load_image(image_info.absolute_path)
|
||||
else:
|
||||
image = None
|
||||
|
||||
if self.caching_mode == "text":
|
||||
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
|
||||
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
else:
|
||||
input_ids1 = None
|
||||
input_ids2 = None
|
||||
if self.caching_mode == "text":
|
||||
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
|
||||
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
else:
|
||||
input_ids1 = None
|
||||
input_ids2 = None
|
||||
|
||||
captions.append(caption)
|
||||
images.append(image)
|
||||
input_ids1_list.append(input_ids1)
|
||||
input_ids2_list.append(input_ids2)
|
||||
absolute_paths.append(image_info.absolute_path)
|
||||
resized_sizes.append(image_info.resized_size)
|
||||
captions.append(caption)
|
||||
images.append(image)
|
||||
input_ids1_list.append(input_ids1)
|
||||
input_ids2_list.append(input_ids2)
|
||||
absolute_paths.append(image_info.absolute_path)
|
||||
resized_sizes.append(image_info.resized_size)
|
||||
|
||||
example = {}
|
||||
|
||||
@@ -2198,12 +2244,27 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
if subset.preference:
|
||||
|
||||
def get_non_preferred_pair_info(img_path, subset):
|
||||
head, file = os.path.split(img_path)
|
||||
head, tail = os.path.split(head)
|
||||
new_tail = tail.replace('w', 'l')
|
||||
new_tail = tail.replace("w", "l")
|
||||
loser_img_path = os.path.join(head, new_tail, file)
|
||||
|
||||
def check_extension(path: str) -> str:
    """Resolve ``path`` to an existing image file, trying other extensions.

    If ``path`` itself exists it is returned unchanged (normalized through
    ``pathlib.Path``). Otherwise its suffix is swapped for each known image
    extension in turn and the first candidate found on disk is returned.

    Args:
        path: Expected location of the image file.

    Returns:
        The string path of the first existing candidate, or the original
        path when no candidate exists.
    """
    from pathlib import Path

    original = Path(path)
    if original.exists():
        return str(original)

    # Probe each extension exactly once (".png" was previously listed twice,
    # causing a redundant filesystem check).
    for ext in (".webp", ".png", ".jpg", ".jpeg"):
        candidate = original.with_suffix(ext)
        if candidate.exists():
            return str(candidate)

    # Nothing matched: hand back the path the caller asked for rather than
    # the last suffix that happened to be tried, so any later "file not
    # found" error names the file that was actually expected.
    return str(original)
|
||||
|
||||
loser_img_path = check_extension(loser_img_path)
|
||||
|
||||
caption = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
|
||||
|
||||
if subset.non_preference_caption_prefix:
|
||||
@@ -2220,17 +2281,25 @@ class DreamBoothDataset(BaseDataset):
|
||||
if subset.preference_caption_suffix:
|
||||
caption = caption + " " + subset.preference_caption_suffix
|
||||
|
||||
info = ImageSetInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
info.image_sizes = [size]
|
||||
else:
|
||||
info.image_sizes = [None]
|
||||
info.add(*get_non_preferred_pair_info(img_path, subset))
|
||||
resize_interpolation = (
|
||||
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
)
|
||||
|
||||
chosen_image_info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
chosen_image_info.resize_interpolation = resize_interpolation
|
||||
rejected_img_path, rejected_caption, rejected_image_size = get_non_preferred_pair_info(img_path, subset)
|
||||
rejected_image_info = ImageInfo(
|
||||
rejected_img_path, subset.num_repeats, caption, subset.is_reg, rejected_img_path
|
||||
)
|
||||
rejected_image_info.resize_interpolation = resize_interpolation
|
||||
|
||||
info = ImageSetInfo([chosen_image_info, rejected_image_info])
|
||||
print(chosen_image_info.image_size, rejected_image_info.image_size)
|
||||
else:
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
info.resize_interpolation = (
|
||||
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
|
||||
@@ -2515,7 +2584,7 @@ class ControlNetDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
@@ -2583,7 +2652,7 @@ class ControlNetDataset(BaseDataset):
|
||||
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_seed = validation_seed
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
# assert all conditioning data exists
|
||||
@@ -2673,7 +2742,14 @@ class ControlNetDataset(BaseDataset):
|
||||
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||
|
||||
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
cond_img = resize_image(
|
||||
cond_img,
|
||||
original_size_hw[1],
|
||||
original_size_hw[0],
|
||||
target_size_hw[1],
|
||||
target_size_hw[0],
|
||||
self.resize_interpolation,
|
||||
)
|
||||
|
||||
# TODO support random crop
|
||||
# 現在サポートしているcropはrandomではなく中央のみ
|
||||
@@ -2687,7 +2763,14 @@ class ControlNetDataset(BaseDataset):
|
||||
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
# resize to target
|
||||
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
|
||||
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
cond_img = resize_image(
|
||||
cond_img,
|
||||
cond_img.shape[0],
|
||||
cond_img.shape[1],
|
||||
target_size_hw[1],
|
||||
target_size_hw[0],
|
||||
self.resize_interpolation,
|
||||
)
|
||||
|
||||
if flipped:
|
||||
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||
@@ -3117,7 +3200,7 @@ def trim_and_resize_if_required(
|
||||
# for new_cache_latents
|
||||
def load_images_and_masks_for_caching(
|
||||
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
|
||||
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
|
||||
) -> Tuple[torch.Tensor, list[torch.Tensor | None], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
|
||||
r"""
|
||||
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
|
||||
|
||||
@@ -3129,38 +3212,47 @@ def load_images_and_masks_for_caching(
|
||||
crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...]
|
||||
"""
|
||||
images: List[torch.Tensor] = []
|
||||
alpha_masks: List[np.ndarray] = []
|
||||
alpha_masks: list[torch.Tensor | None] = []
|
||||
original_sizes: List[Tuple[int, int]] = []
|
||||
crop_ltrbs: List[Tuple[int, int, int, int]] = []
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||
for infos in image_infos:
|
||||
for info in infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
|
||||
)
|
||||
|
||||
original_sizes.append(original_size)
|
||||
crop_ltrbs.append(crop_ltrb)
|
||||
original_sizes.append(original_size)
|
||||
crop_ltrbs.append(crop_ltrb)
|
||||
|
||||
if use_alpha_mask:
|
||||
if image.shape[2] == 4:
|
||||
alpha_mask = image[:, :, 3] # [H,W]
|
||||
alpha_mask = alpha_mask.astype(np.float32) / 255.0
|
||||
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
|
||||
if use_alpha_mask:
|
||||
if image.shape[2] == 4:
|
||||
alpha_mask = image[:, :, 3] # [H,W]
|
||||
alpha_mask = alpha_mask.astype(np.float32) / 255.0
|
||||
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
|
||||
else:
|
||||
alpha_mask = torch.ones_like(torch.from_numpy(image[:, :, 0]), dtype=torch.float32) # [H,W]
|
||||
else:
|
||||
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
|
||||
else:
|
||||
alpha_mask = None
|
||||
alpha_masks.append(alpha_mask)
|
||||
alpha_mask = None
|
||||
alpha_masks.append(alpha_mask)
|
||||
|
||||
image = image[:, :, :3] # remove alpha channel if exists
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
images.append(image)
|
||||
image = image[:, :, :3] # remove alpha channel if exists
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
assert isinstance(image, torch.Tensor)
|
||||
images.append(image)
|
||||
|
||||
img_tensor = torch.stack(images, dim=0)
|
||||
return img_tensor, alpha_masks, original_sizes, crop_ltrbs
|
||||
|
||||
|
||||
def cache_batch_latents(
|
||||
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
|
||||
vae: AutoencoderKL,
|
||||
cache_to_disk: bool,
|
||||
image_infos: list[ImageInfo | ImageSetInfo],
|
||||
flip_aug: bool,
|
||||
use_alpha_mask: bool,
|
||||
random_crop: bool,
|
||||
) -> None:
|
||||
r"""
|
||||
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
|
||||
@@ -3172,29 +3264,32 @@ def cache_batch_latents(
|
||||
latents_original_size and latents_crop_ltrb are also set
|
||||
"""
|
||||
images = []
|
||||
alpha_masks: List[np.ndarray] = []
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||
alpha_masks: List[torch.Tensor | None] = []
|
||||
for infos in image_infos:
|
||||
for info in infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
|
||||
)
|
||||
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
|
||||
if use_alpha_mask:
|
||||
if image.shape[2] == 4:
|
||||
alpha_mask = image[:, :, 3] # [H,W]
|
||||
alpha_mask = alpha_mask.astype(np.float32) / 255.0
|
||||
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
|
||||
if use_alpha_mask:
|
||||
if image.shape[2] == 4:
|
||||
alpha_mask = image[:, :, 3] # [H,W]
|
||||
alpha_mask = alpha_mask.astype(np.float32) / 255.0
|
||||
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
|
||||
else:
|
||||
alpha_mask = torch.ones_like(torch.from_numpy(image[:, :, 0]), dtype=torch.float32) # [H,W]
|
||||
else:
|
||||
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
|
||||
else:
|
||||
alpha_mask = None
|
||||
alpha_masks.append(alpha_mask)
|
||||
alpha_mask = None
|
||||
alpha_masks.append(alpha_mask)
|
||||
|
||||
image = image[:, :, :3] # remove alpha channel if exists
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
images.append(image)
|
||||
image = image[:, :, :3] # remove alpha channel if exists
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
images.append(image)
|
||||
|
||||
img_tensors = torch.stack(images, dim=0)
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
@@ -6176,7 +6271,8 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
|
||||
elif args.huber_schedule == "snr":
|
||||
if not hasattr(noise_scheduler, "alphas_cumprod"):
|
||||
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
|
||||
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
|
||||
device = noise_scheduler.alphas_cumprod.device
|
||||
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.to(device))
|
||||
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
||||
result = result.to(timesteps.device)
|
||||
@@ -6727,4 +6823,3 @@ class LossRecorder:
|
||||
if losses == 0:
|
||||
return 0
|
||||
return self.loss_total / losses
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from PIL import Image
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
    """
    Run ``f(*args, **kwargs)`` on a newly started thread and return immediately.

    The thread object is not returned, so the caller cannot join it or observe
    the callable's result or any exception it raises.
    """
    worker = threading.Thread(target=f, args=args, kwargs=kwargs)
    worker.start()
|
||||
|
||||
@@ -88,6 +89,7 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(msg_init)
|
||||
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -398,7 +400,9 @@ def pil_resize(image, size, interpolation):
|
||||
return resized_cv2
|
||||
|
||||
|
||||
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
|
||||
def resize_image(
|
||||
image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
|
||||
|
||||
@@ -449,29 +453,30 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
||||
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
|
||||
"""
|
||||
if interpolation is None:
|
||||
return None
|
||||
return None
|
||||
|
||||
if interpolation == "lanczos" or interpolation == "lanczos4":
|
||||
# Lanczos interpolation over 8x8 neighborhood
|
||||
# Lanczos interpolation over 8x8 neighborhood
|
||||
return cv2.INTER_LANCZOS4
|
||||
elif interpolation == "nearest":
|
||||
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
|
||||
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
|
||||
return cv2.INTER_NEAREST_EXACT
|
||||
elif interpolation == "bilinear" or interpolation == "linear":
|
||||
# bilinear interpolation
|
||||
return cv2.INTER_LINEAR
|
||||
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||
# bicubic interpolation
|
||||
# bicubic interpolation
|
||||
return cv2.INTER_CUBIC
|
||||
elif interpolation == "area":
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
return cv2.INTER_AREA
|
||||
elif interpolation == "box":
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
return cv2.INTER_AREA
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
|
||||
"""
|
||||
Convert interpolation value to PIL interpolation
|
||||
@@ -479,7 +484,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
|
||||
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
|
||||
"""
|
||||
if interpolation is None:
|
||||
return None
|
||||
return None
|
||||
|
||||
if interpolation == "lanczos":
|
||||
return Image.Resampling.LANCZOS
|
||||
@@ -493,7 +498,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
|
||||
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
|
||||
return Image.Resampling.BICUBIC
|
||||
elif interpolation == "area":
|
||||
# Image.Resampling.BOX may be more appropriate if upscaling
|
||||
# Image.Resampling.BOX may be more appropriate if upscaling
|
||||
# Area interpolation is related to cv2.INTER_AREA
|
||||
# Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX.
|
||||
return Image.Resampling.HAMMING
|
||||
@@ -503,12 +508,37 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def validate_interpolation_fn(interpolation_str: str) -> bool:
    """
    Tell whether ``interpolation_str`` is one of the supported interpolation names.

    The comparison is exact (case-sensitive); anything not in the list below
    yields False.
    """
    supported_names = (
        "lanczos",
        "nearest",
        "bilinear",
        "linear",
        "bicubic",
        "cubic",
        "area",
        "box",
    )
    return interpolation_str in supported_names
|
||||
|
||||
|
||||
# For debugging
|
||||
def save_latent_as_img(vae, latent_to, output_name):
    """
    Debug helper: decode a latent through ``vae`` and write the first decoded image to ``output_name``.

    Args:
        vae: object exposing ``dtype`` and ``decode``; the decode result must support
            ``.float()`` (presumably a tensor shaped [batch, channels, height, width] -- confirm
            against the VAE implementation in use).
        latent_to: latent tensor; it is cast to ``vae.dtype`` before decoding.
        output_name: destination passed straight to ``PIL.Image.Image.save``.

    Only the first item of the batch is saved. Nothing is returned.
    """
    from PIL import Image

    with torch.no_grad():
        decoded = vae.decode(latent_to.to(vae.dtype)).float()

        # Decoder output is assumed to lie roughly in [-1, 1]; shift/scale to [0, 1] and clip the rest.
        unit_range = (decoded / 2 + 0.5).clamp(0, 1)

        # Scale to 0..255 and move to a host numpy array; astype truncates toward zero.
        pixels = (unit_range * 255).cpu().numpy().astype(np.uint8)

        # Channels-first -> channels-last, then keep only the first batch entry.
        first_image = pixels.transpose(0, 2, 3, 1)[0]

        Image.fromarray(first_image).save(output_name)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
Reference in New Issue
Block a user