mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Multi-Resolution Noise
This commit is contained in:
@@ -21,7 +21,7 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
@@ -304,6 +304,8 @@ def train(args):
|
|||||||
if args.noise_offset:
|
if args.noise_offset:
|
||||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||||
|
elif args.multires_noise_iterations:
|
||||||
|
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
@@ -342,3 +343,15 @@ def get_weighted_text_embeddings(
|
|||||||
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
return text_embeddings
|
return text_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
||||||
|
def pyramid_noise_like(noise, device, iterations=6, discount=0.3):
|
||||||
|
b, c, w, h = noise.shape
|
||||||
|
u = torch.nn.Upsample(size=(w, h), mode='bilinear').to(device)
|
||||||
|
for i in range(iterations):
|
||||||
|
r = random.random()*2+2 # Rather than always going 2x,
|
||||||
|
w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
|
||||||
|
noise += u(torch.randn(b, c, w, h).to(device)) * discount**i
|
||||||
|
if w==1 or h==1: break # Lowest resolution is 1x1
|
||||||
|
return noise/noise.std() # Scaled back to roughly unit variance
|
||||||
|
|||||||
@@ -2119,6 +2119,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
default=None,
|
default=None,
|
||||||
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
|
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--multires_noise_iterations",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--multires_noise_discount",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="set discount value for multires noise (has no effect without --multires_noise_iterations)"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lowram",
|
"--lowram",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
@@ -270,6 +270,8 @@ def train(args):
|
|||||||
if args.noise_offset:
|
if args.noise_offset:
|
||||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||||
|
elif args.multires_noise_iterations:
|
||||||
|
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||||
|
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from library.config_util import (
|
|||||||
)
|
)
|
||||||
import library.huggingface_util as huggingface_util
|
import library.huggingface_util as huggingface_util
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like
|
||||||
|
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
@@ -366,6 +366,8 @@ def train(args):
|
|||||||
"ss_seed": args.seed,
|
"ss_seed": args.seed,
|
||||||
"ss_lowram": args.lowram,
|
"ss_lowram": args.lowram,
|
||||||
"ss_noise_offset": args.noise_offset,
|
"ss_noise_offset": args.noise_offset,
|
||||||
|
"ss_multires_noise_iterations": args.multires_noise_iterations,
|
||||||
|
"ss_multires_noise_discount": args.multires_noise_discount,
|
||||||
"ss_training_comment": args.training_comment, # will not be updated after training
|
"ss_training_comment": args.training_comment, # will not be updated after training
|
||||||
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
||||||
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
||||||
@@ -612,6 +614,8 @@ def train(args):
|
|||||||
if args.noise_offset:
|
if args.noise_offset:
|
||||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||||
|
elif args.multires_noise_iterations:
|
||||||
|
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
"a photo of a {}",
|
"a photo of a {}",
|
||||||
@@ -386,6 +386,8 @@ def train(args):
|
|||||||
if args.noise_offset:
|
if args.noise_offset:
|
||||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||||
|
elif args.multires_noise_iterations:
|
||||||
|
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
@@ -425,6 +425,8 @@ def train(args):
|
|||||||
if args.noise_offset:
|
if args.noise_offset:
|
||||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||||
|
elif args.multires_noise_iterations:
|
||||||
|
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user