mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Merge pull request #507 from HkingAuditore/main
Added support for Perlin noise in Noise Offset
This commit is contained in:
@@ -373,3 +373,54 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
|
||||
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
return noise
|
||||
|
||||
|
||||
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
    """Generate one 2D image of Perlin noise on `device`.

    Args:
        device: torch device (or device string) on which all tensors are created.
        shape: (height, width) of the output image. Each entry should be a
            multiple of the matching `res` entry so gradient cells tile evenly.
        res: (res_y, res_x) — number of gradient-grid cells along each axis.
        fade: interpolation curve applied to in-cell coordinates; defaults to
            Perlin's smoothstep 6t^5 - 15t^4 + 10t^3.

    Returns:
        A `shape`-sized tensor of Perlin noise scaled by ~sqrt(2) (1.414) so
        values roughly span [-1, 1].
    """
    # Sample-step inside one gradient cell, and samples-per-cell along each axis.
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])

    # Fractional (in-cell) coordinate of every sample point.
    # indexing="ij" is the historical torch.meshgrid default made explicit;
    # it keeps values identical while silencing the deprecation warning on
    # newer PyTorch versions.
    grid = torch.stack(
        torch.meshgrid(
            torch.arange(0, res[0], delta[0], device=device),
            torch.arange(0, res[1], delta[1], device=device),
            indexing="ij",
        ),
        dim=-1,
    ) % 1

    # One random unit gradient vector per grid corner.
    angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    def tile_grads(slice1, slice2):
        # Expand each corner's gradient so it covers every sample in its cell.
        return gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]] \
            .repeat_interleave(d[0], 0) \
            .repeat_interleave(d[1], 1)

    def dot(grad, shift):
        # Dot product of the corner gradient with the offset from that corner.
        offset = torch.stack(
            (grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
            dim=-1,
        )
        return (offset * grad[:shape[0], :shape[1]]).sum(dim=-1)

    # Noise contribution from each of the four surrounding cell corners.
    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])

    # Bilinear interpolation with the fade curve; 1.414 ~ sqrt(2) normalization.
    t = fade(grid[:shape[0], :shape[1]])
    return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
||||
|
||||
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
    """Accumulate several octaves of 2D Perlin noise into one image.

    Each successive octave doubles the gradient-grid frequency and scales
    the amplitude by `persistence`, producing fractal-style detail.

    Args:
        device: torch device (or device string) for all created tensors.
        shape: (height, width) of the output image.
        res: base (res_y, res_x) grid resolution for the first octave.
        octaves: how many noise layers to sum.
        persistence: per-octave amplitude decay factor.

    Returns:
        A `shape`-sized tensor holding the summed octaves.
    """
    total = torch.zeros(shape, device=device)
    frequency = 1
    amplitude = 1
    layer = 0
    while layer < octaves:
        total = total + amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
        frequency *= 2
        amplitude *= persistence
        layer += 1
    return total
|
||||
|
||||
def perlin_noise(noise, device, octaves):
    """Build Perlin-perturbed noise with the same shape as `noise`.

    Fixes versus the original implementation:
      * `noise.shape` is a `torch.Size` attribute, not a callable, so
        `noise.shape()` raised `TypeError` on every call.
      * The original concatenated three (b, c, w, h) tensors along dim 1,
        returning (b, 3c, w, h); the caller assigns the result back to the
        DDPM `noise` tensor, which must keep the latents' channel count, so
        one independent uniform+Perlin field is now generated per channel
        instead of a hard-coded r/g/b triple.

    Args:
        noise: (b, c, w, h) reference noise tensor; only its shape is used.
        device: device on which the returned noise is created.
        octaves: number of Perlin octaves (see `rand_perlin_2d_octaves`).

    Returns:
        A (b, c, w, h) tensor of uniform noise plus a per-channel Perlin field.
    """
    b, c, w, h = noise.shape
    # Fresh (w, h) Perlin field each call; base grid resolution 4x4.
    perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
    # One independent uniform field + Perlin field per channel, stacked back
    # to the input's channel count so the caller's shape contract holds.
    channels = [torch.rand((b, w, h), device=device) + perlin() for _ in range(c)]
    return torch.stack(channels, dim=1)
|
||||
|
||||
|
||||
|
||||
@@ -2127,6 +2127,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--perlin_noise",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable perlin noise and set the octaves",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multires_noise_discount",
|
||||
type=float,
|
||||
|
||||
@@ -23,7 +23,7 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset,perlin_noise
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -274,6 +274,8 @@ def train(args):
|
||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||
elif args.multires_noise_iterations:
|
||||
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||
elif args.perlin_noise:
|
||||
noise = perlin_noise(noise,latents.device,args.perlin_noise)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
|
||||
Reference in New Issue
Block a user