Merge branch 'dev' into sd3

This commit is contained in:
Kohya S
2024-08-24 21:24:44 +09:00
2 changed files with 7 additions and 2 deletions

View File

@@ -461,6 +461,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
### Working in progress ### Working in progress
- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened!
- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
- The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only Adafactor is supported. Gradient accumulation is not available.

View File

@@ -700,7 +700,11 @@ def train(args):
with accelerator.autocast(): with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
target = noise if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
if ( if (
args.min_snr_gamma args.min_snr_gamma
@@ -718,7 +722,7 @@ def train(args):
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred: if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: if args.v_pred_like_loss: