Merge pull request #471 from pamparamm/multires-noise

Multi-Resolution Noise
This commit is contained in:
Kohya S
2023-05-03 11:17:21 +09:00
committed by GitHub
7 changed files with 42 additions and 5 deletions

View File

@@ -20,7 +20,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
imagenet_templates_small = [
@@ -428,6 +428,8 @@ def train(args):
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)