##########################################
# Perlin Noise
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
    """Generate one 2D Perlin-noise image.

    Args:
        device: torch device on which all tensors are created.
        shape: (H, W) of the output; each entry must be an integer multiple
            of the matching entry of ``res``.
        res: (res_y, res_x) number of gradient-grid cells along each axis.
        fade: interpolation curve; default is Perlin's smootherstep
            6t^5 - 15t^4 + 10t^3.

    Returns:
        Tensor of shape ``shape`` with values roughly in [-1, 1].
    """
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])

    # Fractional (y, x) position of every pixel inside its grid cell.
    grid = (
        torch.stack(
            torch.meshgrid(
                torch.arange(0, res[0], delta[0], device=device),
                torch.arange(0, res[1], delta[1], device=device),
                indexing="ij",  # explicit: silences the deprecation warning, same layout as default
            ),
            dim=-1,
        )
        % 1
    )
    # One random unit gradient per grid-cell corner.
    angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    def tile_grads(slice1, slice2):
        # Repeat each corner gradient across all pixels of its cell.
        return (
            gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]]
            .repeat_interleave(d[0], 0)
            .repeat_interleave(d[1], 1)
        )

    def dot(grad, shift):
        # Dot product of the per-pixel offset (shifted into the given
        # corner's frame) with that corner's gradient.
        return (
            torch.stack(
                (grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
                dim=-1,
            )
            * grad[:shape[0], :shape[1]]
        ).sum(dim=-1)

    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
    t = fade(grid[:shape[0], :shape[1]])
    # 1.414 ~= sqrt(2) rescales classic Perlin output to roughly [-1, 1].
    return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])


def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
    """Sum ``octaves`` Perlin layers (fractal / fBm noise).

    Each successive octave doubles the grid frequency and multiplies the
    amplitude by ``persistence``. ``shape`` must stay divisible by
    ``res * 2**(octaves - 1)`` for the highest octave to be valid.
    """
    noise = torch.zeros(shape, device=device)
    frequency = 1
    amplitude = 1
    for _ in range(octaves):
        noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
        frequency *= 2
        amplitude *= persistence
    return noise


def perlin_noise(noise, device, octaves=1):
    """Return a noise tensor with the same shape as ``noise``, built from
    uniform noise plus an independent fractal-Perlin field per channel.

    Args:
        noise: (B, C, H, W) tensor; only its shape is used.
        device: torch device for the generated tensors.
        octaves: number of Perlin octaves per channel (default 1, matching
            the previous hard-coded behavior).

    Returns:
        (B, C, H, W) tensor. The per-channel Perlin fields are shared across
        the batch dimension.
    """
    # BUG FIX: the original did `noise.shape()` — `shape` is an attribute,
    # so this raised TypeError on every call.
    # NOTE(review): assumed layout is (B, C, H, W); the original unpacked
    # the last two dims as (w, h), which is identical for square latents.
    b, c, h, w = noise.shape
    # BUG FIX: the original concatenated three full-size (B, C, H, W)
    # tensors, tripling one dimension so the result no longer matched
    # `noise` and could not replace it at the call site. Instead, build one
    # independent Perlin field per channel and broadcast over the batch.
    # NOTE(review): the (4, 4) grid resolution is kept from the original;
    # H and W must be divisible by 4 * 2**(octaves - 1).
    perlin = torch.stack(
        [rand_perlin_2d_octaves(device, (h, w), (4, 4), octaves) for _ in range(c)],
        dim=0,
    )
    return torch.rand(noise.shape, device=device) + perlin
enumerate(prompts): if not accelerator.is_main_process: continue - prompt = prompt.strip() - if len(prompt) == 0 or prompt[0] == "#": - continue + + if isinstance(prompt, dict): + negative_prompt = prompt.get("negative_prompt") + sample_steps = prompt.get("sample_steps", 30) + width = prompt.get("width", 512) + height = prompt.get("height", 512) + scale = prompt.get("scale", 7.5) + seed = prompt.get("seed") + prompt = prompt.get("prompt") + else: + # prompt = prompt.strip() + # if len(prompt) == 0 or prompt[0] == "#": + # continue - # subset of gen_img_diffusers - prompt_args = prompt.split(" --") - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - for parg in prompt_args: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue + # subset of gen_img_diffusers + prompt_args = prompt.split(" --") + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + continue - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue - m = re.match(r"d (\d+)", parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) 
+ continue - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) if seed is not None: torch.manual_seed(seed) From 8ab5c8cb28c8efbe97c76b8abdb5a97779e5ac84 Mon Sep 17 00:00:00 2001 From: Linaqruf Date: Sun, 14 May 2023 19:49:54 +0700 Subject: [PATCH 3/4] feat: added json support as well --- library/train_util.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 438c397b..039f58b5 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3303,6 +3303,9 @@ def sample_images( with open(args.sample_prompts, 'r') as f: data = toml.load(f) prompts = [dict(**data['prompt'], **subset) for subset in data['prompt']['subset']] + elif args.sample_prompts.endswith('.json'): + with open(args.sample_prompts, 'r') as f: + prompts = json.load(f) # schedulerを用意する sched_init_args = {} From bca6a44974414f6dc2f7e32423938dbe09bf50af Mon Sep 17 00:00:00 2001 From: hkinghuang <178854663@qq.com> Date: Mon, 15 May 2023 11:16:08 +0800 Subject: [PATCH 4/4] Perlin noise --- library/custom_train_functions.py | 6 +++--- library/train_util.py | 6 ++++++ train_db.py | 4 +++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index a2303a87..d9d85d45 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -410,9 +410,9 @@ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5): amplitude *= persistence return noise -def perlin_noise(noise, device): +def perlin_noise(noise, device,octaves): b, c, w, h = noise.shape() - perlin = lambda : 
rand_perlin_2d_octaves(device,(w,h),(4,4),1) + perlin = lambda : rand_perlin_2d_octaves(device,(w,h),(4,4),octaves) noise_perlin_r = torch.rand(noise.shape, device=device) + perlin() noise_perlin_g = torch.rand(noise.shape, device=device) + perlin() noise_perlin_b = torch.rand(noise.shape, device=device) + perlin() @@ -420,7 +420,7 @@ def perlin_noise(noise, device): (noise_perlin_r, noise_perlin_g, noise_perlin_b), - 2) + 1) return noise_perlin diff --git a/library/train_util.py b/library/train_util.py index 2a55a446..3539a5bd 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2127,6 +2127,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)", ) + parser.add_argument( + "--perlin_noise", + type=int, + default=None, + help="enable perlin noise and set the octaves", + ) parser.add_argument( "--multires_noise_discount", type=float, diff --git a/train_db.py b/train_db.py index 11af9f6b..5425a488 100644 --- a/train_db.py +++ b/train_db.py @@ -23,7 +23,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset,perlin_noise def train(args): @@ -274,6 +274,8 @@ def train(args): noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) elif args.multires_noise_iterations: noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) + elif args.perlin_noise: + noise = perlin_noise(noise,latents.device,args.perlin_noise) # Get the text embedding for 
conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):