revert perlin_noise

This commit is contained in:
Kohya S
2023-05-15 23:12:11 +09:00
parent 08d85d4013
commit 714846e1e1
3 changed files with 73 additions and 50 deletions

View File

@@ -19,6 +19,9 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
parser.add_argument(
"--min_snr_gamma",
@@ -375,23 +378,32 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise
"""
##########################################
# Perlin Noise
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0],device=device), torch.arange(0, res[1], delta[1],device=device)), dim=-1) % 1
grid = (
torch.stack(
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
dim=-1,
)
% 1
)
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
0).repeat_interleave(
d[1], 1)
tile_grads = (
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
dot = lambda grad, shift: (
torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
* grad[: shape[0], : shape[1]]
).sum(dim=-1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
@@ -400,6 +412,7 @@ def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 +
t = fade(grid[: shape[0], : shape[1]])
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
noise = torch.zeros(shape, device=device)
frequency = 1
@@ -410,17 +423,14 @@ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
amplitude *= persistence
return noise
def perlin_noise(noise, device, octaves):
b, c, w, h = noise.shape()
_, c, w, h = noise.shape
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
noise_perlin_r = torch.rand(noise.shape, device=device) + perlin()
noise_perlin_g = torch.rand(noise.shape, device=device) + perlin()
noise_perlin_b = torch.rand(noise.shape, device=device) + perlin()
noise_perlin = torch.cat(
(noise_perlin_r,
noise_perlin_g,
noise_perlin_b),
1)
return noise_perlin
noise_perlin = []
for _ in range(c):
noise_perlin.append(perlin())
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
noise += noise_perlin # broadcast for each batch
return noise / noise.std() # Scaled back to roughly unit variance
"""

View File

@@ -2127,12 +2127,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する有効にする場合は6-10程度を推奨",
)
parser.add_argument(
"--perlin_noise",
type=int,
default=None,
help="enable perlin noise and set the octaves",
)
# parser.add_argument(
# "--perlin_noise",
# type=int,
# default=None,
# help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する",
# )
parser.add_argument(
"--multires_noise_discount",
type=float,
@@ -2217,15 +2217,21 @@ def verify_training_args(args: argparse.Namespace):
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
)
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
# Listを使って数えてもいいけど並べてしまえ
if args.noise_offset is not None and args.multires_noise_iterations is not None:
raise ValueError(
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にすることはできません"
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません"
)
# if args.noise_offset is not None and args.perlin_noise is not None:
# raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません")
# if args.perlin_noise is not None and args.multires_noise_iterations is not None:
# raise ValueError(
# "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません"
# )
if args.adaptive_noise_scale is not None and args.noise_offset is None:
raise ValueError(
"adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です"
)
raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
def add_dataset_arguments(
@@ -3301,16 +3307,16 @@ def sample_images(
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
# prompts = f.readlines()
if args.sample_prompts.endswith('.txt'):
with open(args.sample_prompts, 'r') as f:
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith('.toml'):
with open(args.sample_prompts, 'r') as f:
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r") as f:
data = toml.load(f)
prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']]
elif args.sample_prompts.endswith('.json'):
with open(args.sample_prompts, 'r') as f:
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r") as f:
prompts = json.load(f)
# schedulerを用意する

View File

@@ -23,7 +23,14 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset,perlin_noise
from library.custom_train_functions import (
apply_snr_weight,
get_weighted_text_embeddings,
pyramid_noise_like,
apply_noise_offset,
)
# perlin_noise,
def train(args):
@@ -274,8 +281,8 @@ def train(args):
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
elif args.perlin_noise:
noise = perlin_noise(noise,latents.device,args.perlin_noise)
# elif args.perlin_noise:
# noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):