mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Fix wavelet loss on non-flow matching models (sd1.5, SDXL). Fix wavelet correlation.
This commit is contained in:
@@ -57,6 +57,7 @@ class NetworkTrainer:
|
||||
def __init__(self):
|
||||
self.vae_scale_factor = 0.18215
|
||||
self.is_sdxl = False
|
||||
self.is_flow_matching = False
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(
|
||||
@@ -172,9 +173,9 @@ class NetworkTrainer:
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
val_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
@@ -323,6 +324,7 @@ class NetworkTrainer:
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
return noise_pred, noisy_latents, target, sigmas, timesteps, None, noise
|
||||
|
||||
@@ -472,9 +474,22 @@ class NetworkTrainer:
|
||||
if args.wavelet_loss:
|
||||
def maybe_denoise_latents(denoise_latents: bool, noisy_latents, sigmas, noise_pred, noise):
|
||||
if denoise_latents:
|
||||
# denoise latents to use for wavelet loss
|
||||
wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
|
||||
wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
|
||||
if self.is_flow_matching:
|
||||
# denoise latents to use for wavelet loss
|
||||
wavelet_predicted = (noisy_latents - sigmas * noise_pred) / (1.0 - sigmas)
|
||||
wavelet_target = (noisy_latents - sigmas * noise) / (1.0 - sigmas)
|
||||
|
||||
else:
|
||||
# Get alpha values from scheduler
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod.to(noisy_latents.device)
|
||||
alpha_t = alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)
|
||||
sqrt_alpha_t = torch.sqrt(alpha_t)
|
||||
sqrt_one_minus_alpha_t = torch.sqrt(1.0 - alpha_t)
|
||||
|
||||
# Predict x0 (clean latents) from noise prediction
|
||||
wavelet_predicted = (noisy_latents - sqrt_one_minus_alpha_t * noise_pred) / sqrt_alpha_t
|
||||
wavelet_target = (noisy_latents - sqrt_one_minus_alpha_t * noise) / sqrt_alpha_t
|
||||
|
||||
return wavelet_predicted, wavelet_target
|
||||
else:
|
||||
return noise_pred, target
|
||||
|
||||
Reference in New Issue
Block a user