mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
revert perlin_noise
This commit is contained in:
@@ -19,6 +19,9 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||||
|
|
||||||
|
|
||||||
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min_snr_gamma",
|
"--min_snr_gamma",
|
||||||
@@ -347,7 +350,7 @@ def get_weighted_text_embeddings(
|
|||||||
|
|
||||||
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
||||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||||
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
||||||
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
r = random.random() * 2 + 2 # Rather than always going 2x,
|
r = random.random() * 2 + 2 # Rather than always going 2x,
|
||||||
@@ -369,58 +372,65 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
|||||||
|
|
||||||
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
||||||
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
||||||
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
||||||
|
|
||||||
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
##########################################
|
##########################################
|
||||||
# Perlin Noise
|
# Perlin Noise
|
||||||
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
|
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
||||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||||
d = (shape[0] // res[0], shape[1] // res[1])
|
d = (shape[0] // res[0], shape[1] // res[1])
|
||||||
|
|
||||||
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0],device=device), torch.arange(0, res[1], delta[1],device=device)), dim=-1) % 1
|
grid = (
|
||||||
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1,device=device)
|
torch.stack(
|
||||||
|
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
% 1
|
||||||
|
)
|
||||||
|
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
||||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
||||||
|
|
||||||
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
|
tile_grads = (
|
||||||
0).repeat_interleave(
|
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||||||
d[1], 1)
|
.repeat_interleave(d[0], 0)
|
||||||
|
.repeat_interleave(d[1], 1)
|
||||||
|
)
|
||||||
dot = lambda grad, shift: (
|
dot = lambda grad, shift: (
|
||||||
torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
|
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
||||||
dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
|
* grad[: shape[0], : shape[1]]
|
||||||
|
).sum(dim=-1)
|
||||||
|
|
||||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
||||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
||||||
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
||||||
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
||||||
t = fade(grid[:shape[0], :shape[1]])
|
t = fade(grid[: shape[0], : shape[1]])
|
||||||
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
||||||
|
|
||||||
|
|
||||||
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
||||||
noise = torch.zeros(shape,device=device)
|
noise = torch.zeros(shape, device=device)
|
||||||
frequency = 1
|
frequency = 1
|
||||||
amplitude = 1
|
amplitude = 1
|
||||||
for _ in range(octaves):
|
for _ in range(octaves):
|
||||||
noise += amplitude * rand_perlin_2d(device, shape, (frequency*res[0], frequency*res[1]))
|
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
||||||
frequency *= 2
|
frequency *= 2
|
||||||
amplitude *= persistence
|
amplitude *= persistence
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
def perlin_noise(noise, device,octaves):
|
|
||||||
b, c, w, h = noise.shape()
|
|
||||||
perlin = lambda : rand_perlin_2d_octaves(device,(w,h),(4,4),octaves)
|
|
||||||
noise_perlin_r = torch.rand(noise.shape, device=device) + perlin()
|
|
||||||
noise_perlin_g = torch.rand(noise.shape, device=device) + perlin()
|
|
||||||
noise_perlin_b = torch.rand(noise.shape, device=device) + perlin()
|
|
||||||
noise_perlin = torch.cat(
|
|
||||||
(noise_perlin_r,
|
|
||||||
noise_perlin_g,
|
|
||||||
noise_perlin_b),
|
|
||||||
1)
|
|
||||||
return noise_perlin
|
|
||||||
|
|
||||||
|
|
||||||
|
def perlin_noise(noise, device, octaves):
|
||||||
|
_, c, w, h = noise.shape
|
||||||
|
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
||||||
|
noise_perlin = []
|
||||||
|
for _ in range(c):
|
||||||
|
noise_perlin.append(perlin())
|
||||||
|
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
||||||
|
noise += noise_perlin # broadcast for each batch
|
||||||
|
return noise / noise.std() # Scaled back to roughly unit variance
|
||||||
|
"""
|
||||||
|
|||||||
@@ -2127,12 +2127,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
default=None,
|
default=None,
|
||||||
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
|
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
# parser.add_argument(
|
||||||
"--perlin_noise",
|
# "--perlin_noise",
|
||||||
type=int,
|
# type=int,
|
||||||
default=None,
|
# default=None,
|
||||||
help="enable perlin noise and set the octaves",
|
# help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する",
|
||||||
)
|
# )
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--multires_noise_discount",
|
"--multires_noise_discount",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -2217,15 +2217,21 @@ def verify_training_args(args: argparse.Namespace):
|
|||||||
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
|
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
|
||||||
|
# Listを使って数えてもいいけど並べてしまえ
|
||||||
if args.noise_offset is not None and args.multires_noise_iterations is not None:
|
if args.noise_offset is not None and args.multires_noise_iterations is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にすることはできません"
|
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません"
|
||||||
)
|
)
|
||||||
|
# if args.noise_offset is not None and args.perlin_noise is not None:
|
||||||
|
# raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません")
|
||||||
|
# if args.perlin_noise is not None and args.multires_noise_iterations is not None:
|
||||||
|
# raise ValueError(
|
||||||
|
# "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません"
|
||||||
|
# )
|
||||||
|
|
||||||
if args.adaptive_noise_scale is not None and args.noise_offset is None:
|
if args.adaptive_noise_scale is not None and args.noise_offset is None:
|
||||||
raise ValueError(
|
raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
|
||||||
"adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_dataset_arguments(
|
def add_dataset_arguments(
|
||||||
@@ -3301,18 +3307,18 @@ def sample_images(
|
|||||||
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||||
# prompts = f.readlines()
|
# prompts = f.readlines()
|
||||||
|
|
||||||
if args.sample_prompts.endswith('.txt'):
|
if args.sample_prompts.endswith(".txt"):
|
||||||
with open(args.sample_prompts, 'r') as f:
|
with open(args.sample_prompts, "r") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||||
elif args.sample_prompts.endswith('.toml'):
|
elif args.sample_prompts.endswith(".toml"):
|
||||||
with open(args.sample_prompts, 'r') as f:
|
with open(args.sample_prompts, "r") as f:
|
||||||
data = toml.load(f)
|
data = toml.load(f)
|
||||||
prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']]
|
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||||
elif args.sample_prompts.endswith('.json'):
|
elif args.sample_prompts.endswith(".json"):
|
||||||
with open(args.sample_prompts, 'r') as f:
|
with open(args.sample_prompts, "r") as f:
|
||||||
prompts = json.load(f)
|
prompts = json.load(f)
|
||||||
|
|
||||||
# schedulerを用意する
|
# schedulerを用意する
|
||||||
sched_init_args = {}
|
sched_init_args = {}
|
||||||
if args.sample_sampler == "ddim":
|
if args.sample_sampler == "ddim":
|
||||||
@@ -3381,7 +3387,7 @@ def sample_images(
|
|||||||
for i, prompt in enumerate(prompts):
|
for i, prompt in enumerate(prompts):
|
||||||
if not accelerator.is_main_process:
|
if not accelerator.is_main_process:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(prompt, dict):
|
if isinstance(prompt, dict):
|
||||||
negative_prompt = prompt.get("negative_prompt")
|
negative_prompt = prompt.get("negative_prompt")
|
||||||
sample_steps = prompt.get("sample_steps", 30)
|
sample_steps = prompt.get("sample_steps", 30)
|
||||||
@@ -3390,7 +3396,7 @@ def sample_images(
|
|||||||
scale = prompt.get("scale", 7.5)
|
scale = prompt.get("scale", 7.5)
|
||||||
seed = prompt.get("seed")
|
seed = prompt.get("seed")
|
||||||
prompt = prompt.get("prompt")
|
prompt = prompt.get("prompt")
|
||||||
else:
|
else:
|
||||||
# prompt = prompt.strip()
|
# prompt = prompt.strip()
|
||||||
# if len(prompt) == 0 or prompt[0] == "#":
|
# if len(prompt) == 0 or prompt[0] == "#":
|
||||||
# continue
|
# continue
|
||||||
|
|||||||
13
train_db.py
13
train_db.py
@@ -23,7 +23,14 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset,perlin_noise
|
from library.custom_train_functions import (
|
||||||
|
apply_snr_weight,
|
||||||
|
get_weighted_text_embeddings,
|
||||||
|
pyramid_noise_like,
|
||||||
|
apply_noise_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# perlin_noise,
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
@@ -274,8 +281,8 @@ def train(args):
|
|||||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||||
elif args.multires_noise_iterations:
|
elif args.multires_noise_iterations:
|
||||||
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||||
elif args.perlin_noise:
|
# elif args.perlin_noise:
|
||||||
noise = perlin_noise(noise,latents.device,args.perlin_noise)
|
# noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently
|
||||||
|
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||||
|
|||||||
Reference in New Issue
Block a user