revert perlin_noise

This commit is contained in:
Kohya S
2023-05-15 23:12:11 +09:00
parent 08d85d4013
commit 714846e1e1
3 changed files with 73 additions and 50 deletions

View File

@@ -19,6 +19,9 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
return loss return loss
# TODO train_utilと分散しているのでどちらかに寄せる
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True): def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
parser.add_argument( parser.add_argument(
"--min_snr_gamma", "--min_snr_gamma",
@@ -347,7 +350,7 @@ def get_weighted_text_embeddings(
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4): def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations): for i in range(iterations):
r = random.random() * 2 + 2 # Rather than always going 2x, r = random.random() * 2 + 2 # Rather than always going 2x,
@@ -369,58 +372,65 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
# multiply adaptive noise scale to the mean value and add it to the noise offset # multiply adaptive noise scale to the mean value and add it to the noise offset
noise_offset = noise_offset + adaptive_noise_scale * latent_mean noise_offset = noise_offset + adaptive_noise_scale * latent_mean
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
return noise return noise
"""
########################################## ##########################################
# Perlin Noise # Perlin Noise
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
    """Generate a single 2D Perlin-noise tensor of size ``shape``.

    Args:
        device: torch device on which every tensor is created.
        shape: ``(rows, cols)`` of the returned tensor.
        res: number of gradient-lattice cells along each axis. Each entry of
            ``shape`` is integer-divided by the matching entry of ``res`` to
            get the pixels per cell, so ``shape`` should be an exact multiple
            of ``res``.
        fade: smoothing curve applied to the fractional in-cell coordinates;
            the default is the quintic ``6t^5 - 15t^4 + 10t^3``.

    Returns:
        Tensor of size ``shape`` holding the interpolated corner
        contributions, multiplied by 1.414.
    """
    rows, cols = shape[0], shape[1]
    step = (res[0] / rows, res[1] / cols)
    cell = (rows // res[0], cols // res[1])

    # Fractional position of every pixel inside its lattice cell.
    # indexing="ij" is the layout meshgrid applied implicitly before; stating
    # it explicitly avoids the deprecation warning raised by current torch.
    axis0 = torch.arange(0, res[0], step[0], device=device)
    axis1 = torch.arange(0, res[1], step[1], device=device)
    grid = torch.stack(torch.meshgrid(axis0, axis1, indexing="ij"), dim=-1) % 1
    # A float step can make arange return more samples than requested, so trim
    # once here instead of at every use.
    grid = grid[:rows, :cols]

    # One random unit-length gradient per lattice corner.
    angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    def tile_grads(slice1, slice2):
        # Repeat the selected corner gradients so each one spans a whole cell.
        picked = gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
        return picked.repeat_interleave(cell[0], 0).repeat_interleave(cell[1], 1)

    def dot(grad, shift):
        # Dot product of a corner gradient with the offset from that corner.
        offset = torch.stack((grid[..., 0] + shift[0], grid[..., 1] + shift[1]), dim=-1)
        return (offset * grad[:rows, :cols]).sum(dim=-1)

    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])

    t = fade(grid)
    return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
    """Accumulate several layers of 2D Perlin noise into one tensor.

    Args:
        device: torch device on which the tensors are created.
        shape: size of the returned tensor.
        res: lattice resolution of the first layer; it is doubled for every
            following layer.
        octaves: number of layers to add up.
        persistence: factor by which the layer weight is multiplied after
            each layer (the first layer has weight 1).

    Returns:
        Tensor of size ``shape`` containing the weighted sum of all layers.
    """
    total = torch.zeros(shape, device=device)
    scale = 1
    weight = 1
    for _ in range(octaves):
        layer_res = (scale * res[0], scale * res[1])
        total += weight * rand_perlin_2d(device, shape, layer_res)
        # Next layer: twice the lattice resolution, damped contribution.
        scale *= 2
        weight *= persistence
    return total
def perlin_noise(noise, device, octaves):
    """Add per-channel Perlin noise to ``noise`` and rescale by its std.

    An earlier definition of this function called ``noise.shape()`` (a
    ``torch.Size`` is not callable) and was shadowed by this one; only the
    working variant is kept.

    Args:
        noise: 4D tensor; its last three dimensions are read as
            ``(channels, w, h)``. It is modified in place by the addition.
        device: torch device on which the Perlin layers are generated.
        octaves: number of octaves passed to ``rand_perlin_2d_octaves``.

    Returns:
        ``noise`` (after the in-place addition) divided by its standard
        deviation, as a new tensor.
    """
    _, channels, w, h = noise.shape
    # One independent Perlin field per channel, on a fixed 4x4 lattice.
    layers = [rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves) for _ in range(channels)]
    noise_perlin = torch.stack(layers).unsqueeze(0)  # (1, c, w, h)
    noise += noise_perlin  # broadcast over the batch dimension
    return noise / noise.std()  # scaled back to unit standard deviation
"""

View File

@@ -2127,12 +2127,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None, default=None,
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する有効にする場合は6-10程度を推奨", help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する有効にする場合は6-10程度を推奨",
) )
parser.add_argument( # parser.add_argument(
"--perlin_noise", # "--perlin_noise",
type=int, # type=int,
default=None, # default=None,
help="enable perlin noise and set the octaves", # help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する",
) # )
parser.add_argument( parser.add_argument(
"--multires_noise_discount", "--multires_noise_discount",
type=float, type=float,
@@ -2217,15 +2217,21 @@ def verify_training_args(args: argparse.Namespace):
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
) )
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
# Listを使って数えてもいいけど並べてしまえ
if args.noise_offset is not None and args.multires_noise_iterations is not None: if args.noise_offset is not None and args.multires_noise_iterations is not None:
raise ValueError( raise ValueError(
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にすることはできません" "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません"
) )
# if args.noise_offset is not None and args.perlin_noise is not None:
# raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません")
# if args.perlin_noise is not None and args.multires_noise_iterations is not None:
# raise ValueError(
# "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません"
# )
if args.adaptive_noise_scale is not None and args.noise_offset is None: if args.adaptive_noise_scale is not None and args.noise_offset is None:
raise ValueError( raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
"adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です"
)
def add_dataset_arguments( def add_dataset_arguments(
@@ -3301,18 +3307,18 @@ def sample_images(
# with open(args.sample_prompts, "rt", encoding="utf-8") as f: # with open(args.sample_prompts, "rt", encoding="utf-8") as f:
# prompts = f.readlines() # prompts = f.readlines()
if args.sample_prompts.endswith('.txt'): if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, 'r') as f: with open(args.sample_prompts, "r") as f:
lines = f.readlines() lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith('.toml'): elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, 'r') as f: with open(args.sample_prompts, "r") as f:
data = toml.load(f) data = toml.load(f)
prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']] prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith('.json'): elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, 'r') as f: with open(args.sample_prompts, "r") as f:
prompts = json.load(f) prompts = json.load(f)
# schedulerを用意する # schedulerを用意する
sched_init_args = {} sched_init_args = {}
if args.sample_sampler == "ddim": if args.sample_sampler == "ddim":
@@ -3381,7 +3387,7 @@ def sample_images(
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
if not accelerator.is_main_process: if not accelerator.is_main_process:
continue continue
if isinstance(prompt, dict): if isinstance(prompt, dict):
negative_prompt = prompt.get("negative_prompt") negative_prompt = prompt.get("negative_prompt")
sample_steps = prompt.get("sample_steps", 30) sample_steps = prompt.get("sample_steps", 30)
@@ -3390,7 +3396,7 @@ def sample_images(
scale = prompt.get("scale", 7.5) scale = prompt.get("scale", 7.5)
seed = prompt.get("seed") seed = prompt.get("seed")
prompt = prompt.get("prompt") prompt = prompt.get("prompt")
else: else:
# prompt = prompt.strip() # prompt = prompt.strip()
# if len(prompt) == 0 or prompt[0] == "#": # if len(prompt) == 0 or prompt[0] == "#":
# continue # continue

View File

@@ -23,7 +23,14 @@ from library.config_util import (
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset,perlin_noise from library.custom_train_functions import (
apply_snr_weight,
get_weighted_text_embeddings,
pyramid_noise_like,
apply_noise_offset,
)
# perlin_noise,
def train(args): def train(args):
@@ -274,8 +281,8 @@ def train(args):
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations: elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
elif args.perlin_noise: # elif args.perlin_noise:
noise = perlin_noise(noise,latents.device,args.perlin_noise) # noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently
# Get the text embedding for conditioning # Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):