From 5f1d07d62f4af2c77ef878860b3904537889fb85 Mon Sep 17 00:00:00 2001 From: hkinghuang <178854663@qq.com> Date: Fri, 12 May 2023 21:38:07 +0800 Subject: [PATCH] init --- library/custom_train_functions.py | 51 +++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 0c527c35..a2303a87 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -373,3 +373,54 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) return noise + + + +########################################## +# Perlin Noise +def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3): + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + + grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0],device=device), torch.arange(0, res[1], delta[1],device=device)), dim=-1) % 1 + angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1,device=device) + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) + + tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], + 0).repeat_interleave( + d[1], 1) + dot = lambda grad, shift: ( + torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]), + dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) + t = fade(grid[:shape[0], :shape[1]]) + return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) + +def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5): + noise = torch.zeros(shape,device=device) + frequency = 1 + amplitude = 1 + for _ in range(octaves): + noise += amplitude * rand_perlin_2d(device, shape, (frequency*res[0], frequency*res[1])) + frequency *= 2 + amplitude *= persistence + return noise + +def perlin_noise(noise, device): + b, c, w, h = noise.shape() + perlin = lambda : rand_perlin_2d_octaves(device,(w,h),(4,4),1) + noise_perlin_r = torch.rand(noise.shape, device=device) + perlin() + noise_perlin_g = torch.rand(noise.shape, device=device) + perlin() + noise_perlin_b = torch.rand(noise.shape, device=device) + perlin() + noise_perlin = torch.cat( + (noise_perlin_r, + noise_perlin_g, + noise_perlin_b), + 2) + return noise_perlin + +