add min/max_timestep

This commit is contained in:
Kohya S
2023-07-03 20:44:42 +09:00
parent 5863676ccb
commit ea182461d3
7 changed files with 78 additions and 93 deletions

View File

@@ -51,6 +51,7 @@ from diffusers import (
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
)
from library import custom_train_functions
from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download
import albumentations as albu
@@ -2460,6 +2461,19 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算するNoneの場合は無効、デフォルト",
)
parser.add_argument(
"--min_timestep",
type=int,
default=None,
help="set minimum time step for U-Net training (0~999, default is 0) / U-Net学習時のtime stepの最小値を設定する0~999で指定、省略時はデフォルト値(0) ",
)
parser.add_argument(
"--max_timestep",
type=int,
default=None,
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する1~1000で指定、省略時はデフォルト値(1000)",
)
parser.add_argument(
"--lowram",
action="store_true",
@@ -3688,6 +3702,32 @@ def save_sd_model_on_train_end_common(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
    """Sample training noise and timesteps, and build the noised latents.

    Args:
        args: parsed training arguments. Reads ``noise_offset``,
            ``multires_noise_iterations``, ``min_timestep`` and
            ``max_timestep``; additionally ``adaptive_noise_scale`` when
            ``noise_offset`` is set, or ``multires_noise_discount`` when
            ``multires_noise_iterations`` is set.
        noise_scheduler: object providing ``add_noise(latents, noise, timesteps)``
            and, when ``args.max_timestep`` is None,
            ``config.num_train_timesteps``.
        latents: latent tensor; its first dimension is used as the batch size.

    Returns:
        Tuple ``(noise, noisy_latents, timesteps)`` where ``timesteps`` is an
        int64 tensor with one entry per batch item.

    Raises:
        ValueError: if the resolved minimum timestep is not strictly below the
            resolved maximum timestep (the sampling range would be empty).
    """
    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents, device=latents.device)
    if args.noise_offset:
        noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
    elif args.multires_noise_iterations:
        # Only applied when noise_offset is not set (the two are mutually exclusive here).
        noise = custom_train_functions.pyramid_noise_like(
            noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
        )

    # Resolve the timestep range, falling back to the full range when unspecified.
    b_size = latents.shape[0]
    min_timestep = 0 if args.min_timestep is None else args.min_timestep
    max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

    # Fail early with an actionable message instead of an opaque torch.randint error.
    if min_timestep >= max_timestep:
        raise ValueError(
            f"min_timestep ({min_timestep}) must be smaller than max_timestep ({max_timestep})"
        )

    # Sample a random timestep for each image; the upper bound is exclusive.
    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), dtype=torch.long, device=latents.device)

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    return noise, noisy_latents, timesteps
# scheduler:
# NOTE(review): presumably the start/end values of the scheduler's linear
# (beta) schedule — confirm against the code that consumes these constants.
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
@@ -3807,7 +3847,7 @@ def sample_images_common(
clip_skip=args.clip_skip,
)
pipeline.to(device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)