Add import hint

This commit is contained in:
rockerBOO
2025-07-15 19:53:54 -04:00
parent 479ec9c8a6
commit 9cedf18a97
2 changed files with 8 additions and 1 deletions

View File

@@ -6029,6 +6029,9 @@ def get_noise_noisy_latents_and_timesteps(
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to("cpu")
return noise, noisy_latents, timesteps

View File

@@ -45,7 +45,11 @@ from library.custom_train_functions import (
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
from wavelet_loss import WaveletLoss
try:
from wavelet_loss import WaveletLoss
except:
raise ImportError("WaveletLoss is not installed. Please install it with `pip install git+https://github.com/rockerBOO/wavelet-loss`")
setup_logging()
import logging