mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev
This commit is contained in:
@@ -373,3 +373,54 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
|||||||
|
|
||||||
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
##########################################
|
||||||
|
# Perlin Noise
|
||||||
|
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
    """Generate one 2-D Perlin noise tensor of size ``shape`` on ``device``.

    Args:
        device: torch device for all intermediate and output tensors.
        shape: (H, W) of the output; each entry should be divisible by the
            matching entry of ``res`` so cells tile the output exactly.
        res: number of noise-grid cells along each axis.
        fade: smoothstep-style curve applied to the cell-local coordinates
            (default is Perlin's classic 6t^5 - 15t^4 + 10t^3).

    Returns:
        A ``shape``-sized float tensor of Perlin noise, roughly in [-1, 1].
    """
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])

    # Cell-local (fractional) coordinates for every output pixel.
    # FIX: pass indexing="ij" explicitly — this is the behavior the code
    # relies on, and omitting it raises a deprecation warning on modern torch.
    grid = torch.stack(
        torch.meshgrid(
            torch.arange(0, res[0], delta[0], device=device),
            torch.arange(0, res[1], delta[1], device=device),
            indexing="ij",
        ),
        dim=-1,
    ) % 1
    # Random unit gradient vector at every grid corner.
    angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    # Expand one quadrant of corner gradients up to pixel resolution.
    tile_grads = lambda slice1, slice2: gradients[
        slice1[0]:slice1[1], slice2[0]:slice2[1]
    ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
    # Dot product between a gradient field and each pixel's offset to a corner.
    dot = lambda grad, shift: (
        torch.stack(
            (grid[:shape[0], :shape[1], 0] + shift[0],
             grid[:shape[0], :shape[1], 1] + shift[1]),
            dim=-1,
        ) * grad[:shape[0], :shape[1]]
    ).sum(dim=-1)

    # Contributions from the four corners of each cell.
    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
    t = fade(grid[:shape[0], :shape[1]])
    # Bilinear interpolation of the corner contributions; 1.414 ≈ sqrt(2)
    # rescales the classic Perlin range towards [-1, 1].
    return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
||||||
|
|
||||||
|
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
    """Accumulate ``octaves`` layers of 2-D Perlin noise.

    Each successive octave doubles the base grid resolution ``res`` and
    scales its amplitude by ``persistence``, producing the usual fractal
    (fBm-style) sum. Returns a tensor of size ``shape`` on ``device``.
    """
    total = torch.zeros(shape, device=device)
    freq = 1
    amp = 1
    for _ in range(octaves):
        layer = rand_perlin_2d(device, shape, (freq * res[0], freq * res[1]))
        total += amp * layer
        freq *= 2
        amp *= persistence
    return total
|
||||||
|
|
||||||
|
def perlin_noise(noise, device, octaves):
    """Blend Perlin-octave noise into fresh uniform noise tensors.

    Args:
        noise: 4-D tensor (batch, channels, w, h) — only its shape is read.
        device: torch device for the generated tensors.
        octaves: number of Perlin octaves to accumulate per field.

    Returns:
        Three full-shape uniform+Perlin tensors concatenated on dim 1.
    """
    # BUG FIX: ``torch.Size`` is not callable — the original ``noise.shape()``
    # raised TypeError on every call.
    b, c, w, h = noise.shape
    # One independent (4, 4)-cell Perlin field per call; it broadcasts over
    # the leading batch/channel dims when added to the full-shape rand tensor.
    perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
    noise_perlin_r = torch.rand(noise.shape, device=device) + perlin()
    noise_perlin_g = torch.rand(noise.shape, device=device) + perlin()
    noise_perlin_b = torch.rand(noise.shape, device=device) + perlin()
    # NOTE(review): concatenating three full-shape tensors on dim 1 yields
    # (b, 3*c, w, h), which does not match the latents' channel count at the
    # visible call site — confirm the intended output shape with the caller.
    noise_perlin = torch.cat(
        (noise_perlin_r,
         noise_perlin_g,
         noise_perlin_b),
        1)
    return noise_perlin
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2127,6 +2127,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
default=None,
|
default=None,
|
||||||
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
|
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--perlin_noise",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="enable perlin noise and set the octaves",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--multires_noise_discount",
|
"--multires_noise_discount",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -3291,8 +3297,21 @@ def sample_images(
|
|||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
# read prompts
|
# read prompts
|
||||||
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
|
||||||
prompts = f.readlines()
|
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||||
|
# prompts = f.readlines()
|
||||||
|
|
||||||
|
if args.sample_prompts.endswith('.txt'):
|
||||||
|
with open(args.sample_prompts, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||||
|
elif args.sample_prompts.endswith('.toml'):
|
||||||
|
with open(args.sample_prompts, 'r') as f:
|
||||||
|
data = toml.load(f)
|
||||||
|
prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']]
|
||||||
|
elif args.sample_prompts.endswith('.json'):
|
||||||
|
with open(args.sample_prompts, 'r') as f:
|
||||||
|
prompts = json.load(f)
|
||||||
|
|
||||||
# schedulerを用意する
|
# schedulerを用意する
|
||||||
sched_init_args = {}
|
sched_init_args = {}
|
||||||
@@ -3362,53 +3381,63 @@ def sample_images(
|
|||||||
for i, prompt in enumerate(prompts):
|
for i, prompt in enumerate(prompts):
|
||||||
if not accelerator.is_main_process:
|
if not accelerator.is_main_process:
|
||||||
continue
|
continue
|
||||||
prompt = prompt.strip()
|
|
||||||
if len(prompt) == 0 or prompt[0] == "#":
|
|
||||||
continue
|
|
||||||
|
|
||||||
# subset of gen_img_diffusers
|
if isinstance(prompt, dict):
|
||||||
prompt_args = prompt.split(" --")
|
negative_prompt = prompt.get("negative_prompt")
|
||||||
prompt = prompt_args[0]
|
sample_steps = prompt.get("sample_steps", 30)
|
||||||
negative_prompt = None
|
width = prompt.get("width", 512)
|
||||||
sample_steps = 30
|
height = prompt.get("height", 512)
|
||||||
width = height = 512
|
scale = prompt.get("scale", 7.5)
|
||||||
scale = 7.5
|
seed = prompt.get("seed")
|
||||||
seed = None
|
prompt = prompt.get("prompt")
|
||||||
for parg in prompt_args:
|
else:
|
||||||
try:
|
# prompt = prompt.strip()
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
# if len(prompt) == 0 or prompt[0] == "#":
|
||||||
if m:
|
# continue
|
||||||
width = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
# subset of gen_img_diffusers
|
||||||
if m:
|
prompt_args = prompt.split(" --")
|
||||||
height = int(m.group(1))
|
prompt = prompt_args[0]
|
||||||
continue
|
negative_prompt = None
|
||||||
|
sample_steps = 30
|
||||||
|
width = height = 512
|
||||||
|
scale = 7.5
|
||||||
|
seed = None
|
||||||
|
for parg in prompt_args:
|
||||||
|
try:
|
||||||
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
width = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
seed = int(m.group(1))
|
height = int(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||||
if m: # steps
|
if m:
|
||||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
seed = int(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||||
if m: # scale
|
if m: # steps
|
||||||
scale = float(m.group(1))
|
sample_steps = max(1, min(1000, int(m.group(1))))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # negative prompt
|
if m: # scale
|
||||||
negative_prompt = m.group(1)
|
scale = float(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
if m: # negative prompt
|
||||||
print(ex)
|
negative_prompt = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except ValueError as ex:
|
||||||
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
|
print(ex)
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
|
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset,perlin_noise
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
@@ -274,6 +274,8 @@ def train(args):
|
|||||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||||
elif args.multires_noise_iterations:
|
elif args.multires_noise_iterations:
|
||||||
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||||
|
elif args.perlin_noise:
|
||||||
|
noise = perlin_noise(noise,latents.device,args.perlin_noise)
|
||||||
|
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||||
|
|||||||
Reference in New Issue
Block a user