Merge pull request #750 from kohya-ss/dev

block lr for U-Net with SDXL etc.
This commit is contained in:
Kohya S
2023-08-12 13:17:06 +09:00
committed by GitHub
7 changed files with 176 additions and 38 deletions

View File

@@ -52,6 +52,10 @@ def main(args):
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります" # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0: if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
if args.bucket_reso_steps % 32 > 0:
print(
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
)
train_data_dir_path = Path(args.train_data_dir) train_data_dir_path = Path(args.train_data_dir)
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)] image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]

View File

@@ -13,7 +13,7 @@ from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeigh
TOKENIZER1_PATH = "openai/clip-vit-large-patch14" TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
DEFAULT_NOISE_OFFSET = 0.0357 # DEFAULT_NOISE_OFFSET = 0.0357
def load_target_model(args, accelerator, model_version: str, weight_dtype): def load_target_model(args, accelerator, model_version: str, weight_dtype):
@@ -312,18 +312,18 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
if args.clip_skip is not None: if args.clip_skip is not None:
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
if args.multires_noise_iterations: # if args.multires_noise_iterations:
print( # print(
f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
) # )
else: # else:
if args.noise_offset is None: # if args.noise_offset is None:
args.noise_offset = DEFAULT_NOISE_OFFSET # args.noise_offset = DEFAULT_NOISE_OFFSET
elif args.noise_offset != DEFAULT_NOISE_OFFSET: # elif args.noise_offset != DEFAULT_NOISE_OFFSET:
print( # print(
f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
) # )
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") # print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert ( assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions not hasattr(args, "weighted_captions") or not args.weighted_captions

View File

@@ -800,6 +800,12 @@ class BaseDataset(torch.utils.data.Dataset):
random.shuffle(self.buckets_indices) random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle() self.bucket_manager.shuffle()
def verify_bucket_reso_steps(self, min_steps: int):
    """Assert that this dataset's bucket resolution step is a multiple of ``min_steps``.

    A ``bucket_reso_steps`` of ``None`` is accepted without any check.

    Args:
        min_steps: the divisor that ``self.bucket_reso_steps`` must be a multiple of.

    Raises:
        AssertionError: if ``self.bucket_reso_steps`` is set and is not divisible
            by ``min_steps``.
    """
    steps = self.bucket_reso_steps
    is_acceptable = steps is None or steps % min_steps == 0
    # The message is bilingual (English / Japanese) and is only built on failure.
    assert is_acceptable, (
        f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n"
        + f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります"
    )
def is_latent_cacheable(self): def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
@@ -1831,6 +1837,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets: for dataset in self.datasets:
dataset.set_caching_mode(caching_mode) dataset.set_caching_mode(caching_mode)
def verify_bucket_reso_steps(self, min_steps: int):
    """Run ``verify_bucket_reso_steps(min_steps)`` on every dataset in ``self.datasets``.

    Datasets are checked in order; whatever the first failing dataset raises
    propagates to the caller.
    """
    for member in self.datasets:
        member.verify_bucket_reso_steps(min_steps)
def is_latent_cacheable(self) -> bool: def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets]) return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -2020,6 +2030,9 @@ class MinimalDataset(BaseDataset):
self.is_reg = False self.is_reg = False
self.image_dir = "dummy" # for metadata self.image_dir = "dummy" # for metadata
def verify_bucket_reso_steps(self, min_steps: int):
    """Do nothing: this dataset performs no bucket resolution step check."""
    return None
def is_latent_cacheable(self) -> bool: def is_latent_cacheable(self) -> bool:
return False return False
@@ -2981,11 +2994,11 @@ def verify_training_args(args: argparse.Namespace):
) )
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time # noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
# Listを使って数えてもいいけど並べてしまえ # # Listを使って数えてもいいけど並べてしまえ
if args.noise_offset is not None and args.multires_noise_iterations is not None: # if args.noise_offset is not None and args.multires_noise_iterations is not None:
raise ValueError( # raise ValueError(
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません" # "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません"
) # )
# if args.noise_offset is not None and args.perlin_noise is not None: # if args.noise_offset is not None and args.perlin_noise is not None:
# raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません") # raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません")
# if args.perlin_noise is not None and args.multires_noise_iterations is not None: # if args.perlin_noise is not None and args.multires_noise_iterations is not None:
@@ -4268,7 +4281,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
noise = torch.randn_like(latents, device=latents.device) noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset: if args.noise_offset:
noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations: if args.multires_noise_iterations:
noise = custom_train_functions.pyramid_noise_like( noise = custom_train_functions.pyramid_noise_like(
noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
) )

View File

@@ -1309,7 +1309,10 @@ def main(args):
# schedulerを用意する # schedulerを用意する
sched_init_args = {} sched_init_args = {}
has_steps_offset = True
has_clip_sample = True
scheduler_num_noises_per_step = 1 scheduler_num_noises_per_step = 1
if args.sampler == "ddim": if args.sampler == "ddim":
scheduler_cls = DDIMScheduler scheduler_cls = DDIMScheduler
scheduler_module = diffusers.schedulers.scheduling_ddim scheduler_module = diffusers.schedulers.scheduling_ddim
@@ -1319,32 +1322,48 @@ def main(args):
elif args.sampler == "pndm": elif args.sampler == "pndm":
scheduler_cls = PNDMScheduler scheduler_cls = PNDMScheduler
scheduler_module = diffusers.schedulers.scheduling_pndm scheduler_module = diffusers.schedulers.scheduling_pndm
has_clip_sample = False
elif args.sampler == "lms" or args.sampler == "k_lms": elif args.sampler == "lms" or args.sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler scheduler_cls = LMSDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_lms_discrete scheduler_module = diffusers.schedulers.scheduling_lms_discrete
has_clip_sample = False
elif args.sampler == "euler" or args.sampler == "k_euler": elif args.sampler == "euler" or args.sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler scheduler_cls = EulerDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_euler_discrete scheduler_module = diffusers.schedulers.scheduling_euler_discrete
has_clip_sample = False
elif args.sampler == "euler_a" or args.sampler == "k_euler_a": elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler scheduler_cls = EulerAncestralDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
has_clip_sample = False
elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = args.sampler sched_init_args["algorithm_type"] = args.sampler
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
has_clip_sample = False
elif args.sampler == "dpmsingle": elif args.sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler scheduler_cls = DPMSolverSinglestepScheduler
scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep
has_clip_sample = False
has_steps_offset = False
elif args.sampler == "heun": elif args.sampler == "heun":
scheduler_cls = HeunDiscreteScheduler scheduler_cls = HeunDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_heun_discrete scheduler_module = diffusers.schedulers.scheduling_heun_discrete
has_clip_sample = False
elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler scheduler_cls = KDPM2DiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete
has_clip_sample = False
elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler scheduler_cls = KDPM2AncestralDiscreteScheduler
scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
scheduler_num_noises_per_step = 2 scheduler_num_noises_per_step = 2
has_clip_sample = False
# 警告を出さないようにする
if has_steps_offset:
sched_init_args["steps_offset"] = 1
if has_clip_sample:
sched_init_args["clip_sample"] = False
# samplerの乱数をあらかじめ指定するための処理 # samplerの乱数をあらかじめ指定するための処理
@@ -1397,10 +1416,11 @@ def main(args):
**sched_init_args, **sched_init_args,
) )
# clip_sample=Trueにする # ↓以下は結局PipeでFalseに設定されるので意味がなかった
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: # # clip_sample=Trueにする
print("set clip_sample to True") # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
scheduler.config.clip_sample = True # print("set clip_sample to True")
# scheduler.config.clip_sample = True
# deviceを決定する # deviceを決定する
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない

View File

@@ -5,6 +5,7 @@ import gc
import math import math
import os import os
from multiprocessing import Value from multiprocessing import Value
from typing import List
import toml import toml
from tqdm import tqdm from tqdm import tqdm
@@ -30,6 +31,67 @@ from library.custom_train_functions import (
from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.sdxl_original_unet import SdxlUNet2DConditionModel
# Number of U-Net parameter groups that can be given their own learning rate.
# Follows the index layout used by get_block_params_to_optimize:
# 1 (time_embed / label_emb) + 9 (input_blocks) + 3 (middle_block) + 9 (output_blocks) + 1 (out) = 23.
UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
    """Group the U-Net parameters by block and attach a learning rate to each group.

    Parameters are assigned to a block index from the prefix of their name:

    * ``time_embed.`` / ``label_emb.`` -> 0
    * ``input_blocks.<k>.``            -> 1 + k   (1-9)
    * ``middle_block.<k>.``            -> 10 + k  (10-12)
    * ``output_blocks.<k>.``           -> 13 + k  (13-21)
    * ``out.``                         -> 22

    Args:
        unet: object whose ``named_parameters()`` yields ``(name, parameter)`` pairs.
        block_lrs: one learning rate per block index.

    Returns:
        A list of ``{"params": [...], "lr": lr}`` dicts in block order. Blocks whose
        learning rate is 0 are left out entirely (they are not optimized).

    Raises:
        ValueError: if a parameter name matches none of the known prefixes.
    """
    # (name prefix, block index of sub-block 0) for the prefixes that carry a sub-block number
    numbered_prefixes = (("input_blocks.", 1), ("middle_block.", 10), ("output_blocks.", 13))

    grouped = [[] for _ in block_lrs]
    for param_name, parameter in unet.named_parameters():
        if param_name.startswith(("time_embed.", "label_emb.")):
            target = 0
        elif param_name.startswith("out."):
            target = 22
        else:
            for prefix, base in numbered_prefixes:
                if param_name.startswith(prefix):
                    # the second dotted component is the sub-block number
                    target = base + int(param_name.split(".")[1])
                    break
            else:
                raise ValueError(f"unexpected parameter name: {param_name}")
        grouped[target].append(parameter)

    # do not optimize when lr is 0
    return [{"params": members, "lr": rate} for members, rate in zip(grouped, block_lrs) if rate != 0]
def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
    """Add one learning-rate entry per optimizer param group to ``logs``.

    The values of ``lr_scheduler.get_last_lr()`` are consumed in order and labelled
    ``block0`` ... ``block22`` for the U-Net blocks, then ``text_encoder1`` and
    ``text_encoder2``. U-Net blocks whose entry in ``block_lrs`` is 0 are skipped
    without consuming a value.
    NOTE(review): this presumes the scheduler's lr list follows the same order in
    which the param groups were built -- confirm against the caller.

    Args:
        block_lrs: per-block learning rates for the U-Net blocks.
        logs: dict that is updated in place with ``"lr/<name>"`` keys.
        lr_scheduler: object providing ``get_last_lr()`` and ``optimizers``.
        optimizer_type: optimizer name. If it starts with "DAdapt" or equals
            "Prodigy" (case-insensitive), ``"lr/d*lr/<name>"`` is also logged as
            ``d * lr`` of the matching param group of the last optimizer.

    Raises:
        ValueError: if there are more learning rates than blocks plus two text encoders.
    """
    last_lrs = lr_scheduler.get_last_lr()
    text_encoder_labels = {
        UNET_NUM_BLOCKS_FOR_BLOCK_LR: "text_encoder1",
        UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1: "text_encoder2",
    }

    group_index = 0  # position in last_lrs / param_groups
    next_slot = 0  # position in the block0..block22, text_encoder1, text_encoder2 sequence
    while group_index < len(last_lrs):
        slot = next_slot
        next_slot += 1

        if slot < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
            if block_lrs[slot] == 0:
                # this block has no lr entry to consume; move on to the next slot
                continue
            label = f"block{slot}"
        elif slot in text_encoder_labels:
            label = text_encoder_labels[slot]
        else:
            raise ValueError(f"unexpected block_index: {slot}")

        logs["lr/" + label] = float(last_lrs[group_index])

        lowered_type = optimizer_type.lower()
        if lowered_type.startswith("DAdapt".lower()) or lowered_type == "Prodigy".lower():
            # tracking d*lr value
            group = lr_scheduler.optimizers[-1].param_groups[group_index]
            logs["lr/d*lr/" + label] = group["d"] * group["lr"]

        group_index += 1
def train(args): def train(args):
train_util.verify_training_args(args) train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True) train_util.prepare_dataset_args(args, True)
@@ -40,6 +102,14 @@ def train(args):
not args.train_text_encoder or not args.cache_text_encoder_outputs not args.train_text_encoder or not args.cache_text_encoder_outputs
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
if args.block_lr:
block_lrs = [float(lr) for lr in args.block_lr.split(",")]
assert (
len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
else:
block_lrs = None
cache_latents = args.cache_latents cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None use_dreambooth_method = args.in_json is None
@@ -98,6 +168,8 @@ def train(args):
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset: if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, True) train_util.debug_dataset(train_dataset_group, True)
return return
@@ -233,15 +305,28 @@ def train(args):
for m in training_models: for m in training_models:
m.requires_grad_(True) m.requires_grad_(True)
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
# calculate number of trainable parameters if block_lrs is None:
n_params = 0 params = []
for p in params: for m in training_models:
n_params += p.numel() params.extend(m.parameters())
params_to_optimize = params
# calculate number of trainable parameters
n_params = 0
for p in params:
n_params += p.numel()
else:
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
for m in training_models[1:]: # Text Encoders if exists
params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}") accelerator.print(f"number of trainable parameters: {n_params}")
@@ -526,13 +611,18 @@ def train(args):
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} logs = {"loss": current_loss}
if ( if block_lrs is None:
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy" logs["lr"] = float(lr_scheduler.get_last_lr()[0])
): # tracking d*lr value if (
logs["lr/d*lr"] = ( args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ): # tracking d*lr value
) logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
else:
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
# TODO moving averageにする # TODO moving averageにする
@@ -636,6 +726,13 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true", action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
) )
parser.add_argument(
"--block_lr",
type=str,
default=None,
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
)
return parser return parser

View File

@@ -23,6 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
args.network_train_unet_only or not args.cache_text_encoder_outputs args.network_train_unet_only or not args.cache_text_encoder_outputs
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
train_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator): def load_target_model(self, args, weight_dtype, accelerator):
( (
load_stable_diffusion_format, load_stable_diffusion_format,

View File

@@ -19,6 +19,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
super().assert_extra_args(args, train_dataset_group) super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
train_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator): def load_target_model(self, args, weight_dtype, accelerator):
( (
load_stable_diffusion_format, load_stable_diffusion_format,