mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #1505 from liesened/patch-2
Add v-pred support for SDXL train
This commit is contained in:
@@ -702,7 +702,11 @@ def train(args):
|
|||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||||
|
|
||||||
target = noise
|
if args.v_parameterization:
|
||||||
|
# v-parameterization training
|
||||||
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
target = noise
|
||||||
|
|
||||||
if (
|
if (
|
||||||
args.min_snr_gamma
|
args.min_snr_gamma
|
||||||
@@ -720,7 +724,7 @@ def train(args):
|
|||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
|
|||||||
Reference in New Issue
Block a user