revert perlin_noise

This commit is contained in:
Kohya S
2023-05-15 23:12:11 +09:00
parent 08d85d4013
commit 714846e1e1
3 changed files with 73 additions and 50 deletions

View File

@@ -19,6 +19,9 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
parser.add_argument(
"--min_snr_gamma",
@@ -347,7 +350,7 @@ def get_weighted_text_embeddings(
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations):
r = random.random() * 2 + 2 # Rather than always going 2x,
@@ -369,58 +372,65 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
# multiply adaptive noise scale to the mean value and add it to the noise offset
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
return noise
"""
##########################################
# Perlin Noise
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0],device=device), torch.arange(0, res[1], delta[1],device=device)), dim=-1) % 1
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1,device=device)
grid = (
torch.stack(
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
dim=-1,
)
% 1
)
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
0).repeat_interleave(
d[1], 1)
tile_grads = (
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
dot = lambda grad, shift: (
torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
* grad[: shape[0], : shape[1]]
).sum(dim=-1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
t = fade(grid[:shape[0], :shape[1]])
t = fade(grid[: shape[0], : shape[1]])
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
noise = torch.zeros(shape,device=device)
noise = torch.zeros(shape, device=device)
frequency = 1
amplitude = 1
for _ in range(octaves):
noise += amplitude * rand_perlin_2d(device, shape, (frequency*res[0], frequency*res[1]))
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
frequency *= 2
amplitude *= persistence
return noise
def perlin_noise(noise, device,octaves):
b, c, w, h = noise.shape()
perlin = lambda : rand_perlin_2d_octaves(device,(w,h),(4,4),octaves)
noise_perlin_r = torch.rand(noise.shape, device=device) + perlin()
noise_perlin_g = torch.rand(noise.shape, device=device) + perlin()
noise_perlin_b = torch.rand(noise.shape, device=device) + perlin()
noise_perlin = torch.cat(
(noise_perlin_r,
noise_perlin_g,
noise_perlin_b),
1)
return noise_perlin
def perlin_noise(noise, device, octaves):
_, c, w, h = noise.shape
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
noise_perlin = []
for _ in range(c):
noise_perlin.append(perlin())
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
noise += noise_perlin # broadcast for each batch
return noise / noise.std() # Scaled back to roughly unit variance
"""