From 747af145ed32eb85205dca144a4e49f25032d130 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 26 Jun 2023 08:07:24 +0900 Subject: [PATCH] add sdxl fine-tuning and LoRA --- library/model_util.py | 10 + library/sdxl_model_util.py | 8 +- library/sdxl_original_unet.py | 12 +- library/sdxl_train_util.py | 384 +++++++++ library/train_util.py | 103 ++- networks/lora.py | 90 +- networks/sdxl_merge_lora.py | 258 ++++++ sdxl_minimal_inference.py | 29 +- sdxl_train.py | 605 +++++++++++++ sdxl_train_network.py | 172 ++++ train_network.py | 1525 ++++++++++++++++++--------------- 11 files changed, 2442 insertions(+), 754 deletions(-) create mode 100644 library/sdxl_train_util.py create mode 100644 networks/sdxl_merge_lora.py create mode 100644 sdxl_train.py create mode 100644 sdxl_train_network.py diff --git a/library/model_util.py b/library/model_util.py index fce08be8..938b7692 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -1061,6 +1061,16 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt return text_model, vae, unet +def get_model_version_str_for_sd1_sd2(v2, v_parameterization): + # only for reference + version_str = "sd" + if v2: + version_str += "_v2" + else: + version_str += "_v1" + if v_parameterization: + version_str += "_v" + return version_str def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): def convert_key(key): diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index fc64d21e..c554782b 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -6,6 +6,10 @@ from library import model_util from library import sdxl_original_unet +VAE_SCALE_FACTOR = 0.13025 +MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9" + + def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): SDXL_KEY_PREFIX = "conditioner.embedders.1.model." 
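# For reference (not part of the patch): SDXL's VAE uses a latent scale of 0.13025 instead of
# SD1/2's 0.18215. Later hunks multiply encoded latents by sdxl_model_util.VAE_SCALE_FACTOR
# (sdxl_train.py) and divide by it before decoding (sdxl_minimal_inference.py). A minimal sketch,
# assuming `vae` is a diffusers AutoencoderKL loaded from the SDXL checkpoint:
import torch
from library import sdxl_model_util

def encode_to_latents(vae, images: torch.Tensor) -> torch.Tensor:
    # images: (B, 3, H, W) scaled to [-1, 1]
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample()
    return latents * sdxl_model_util.VAE_SCALE_FACTOR

def decode_from_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return vae.decode(latents / sdxl_model_util.VAE_SCALE_FACTOR).sample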
@@ -76,8 +80,8 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): return new_sd, text_projection, logit_scale -def load_models_from_sdxl_checkpoint(model_type, ckpt_path, map_location): - # model_type is reserved to future use +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location): + # model_version is reserved for future use # Load the state dict if model_util.is_safetensors(ckpt_path): diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index fd37432f..8ba1c988 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -1069,8 +1069,8 @@ class SdxlUNet2DConditionModel(nn.Module): t_emb = t_emb.to(x.dtype) emb = self.time_embed(t_emb) - assert y.shape[0] == x.shape[0] - assert x.dtype == y.dtype + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" # assert x.dtype == self.dtype emb = emb + self.label_emb(y) @@ -1105,6 +1105,8 @@ class SdxlUNet2DConditionModel(nn.Module): if __name__ == "__main__": + import time + print("create unet") unet = SdxlUNet2DConditionModel() @@ -1132,8 +1134,11 @@ if __name__ == "__main__": print("start training") steps = 10 batch_size = 1 + for step in range(steps): print(f"step {step}") + if step == 1: + time_start = time.perf_counter() x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda") @@ -1149,3 +1154,6 @@ if __name__ == "__main__": scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py new file mode 100644 index 00000000..e10488c0 --- /dev/null +++ b/library/sdxl_train_util.py @@ -0,0 +1,384 @@ +import argparse +import gc +import math +import os +from types import SimpleNamespace +from typing import Any +import torch +from tqdm import tqdm +from transformers import CLIPTokenizer +import open_clip +from library import model_util, sdxl_model_util, train_util + +TOKENIZER_PATH = "openai/clip-vit-large-patch14" + +DEFAULT_NOISE_OFFSET = 0.0357 + + +# TODO: separate checkpoint for each U-Net/Text Encoder/VAE +def load_target_model(args, accelerator, model_version: str, weight_dtype): + # load models for each process + for pi in range(accelerator.state.num_processes): + if pi == accelerator.state.local_process_index: + print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + text_projection, + logit_scale, + ckpt_info, + ) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu") + + # work on low-ram device + if args.lowram: + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + unet.to(accelerator.device) + vae.to(accelerator.device) + + gc.collect() + torch.cuda.empty_cache() + accelerator.wait_for_everyone() + + text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet]) + + return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info + + +def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"): + # only supports StableDiffusion + 
name_or_path = args.pretrained_model_name_or_path + name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path + load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + assert ( + load_stable_diffusion_format + ), f"only supports StableDiffusion format for SDXL / SDXLではStableDiffusion形式のみサポートしています: {name_or_path}" + + print(f"load StableDiffusion checkpoint: {name_or_path}") + ( + text_encoder1, + text_encoder2, + vae, + unet, + text_projection, + logit_scale, + ckpt_info, + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device) + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") + + return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info + + +class WrapperTokenizer: + # open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする + def __init__(self): + open_clip_tokenizer = open_clip.tokenizer._tokenizer + self.model_max_length = 77 + self.bos_token_id = open_clip_tokenizer.all_special_ids[0] + self.eos_token_id = open_clip_tokenizer.all_special_ids[1] + self.pad_token_id = 0 # 結果から推定している + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.tokenize(*args, **kwds) + + def tokenize(self, text, padding, truncation, max_length, return_tensors): + assert padding == "max_length" + assert truncation == True + assert return_tensors == "pt" + input_ids = open_clip.tokenize(text, context_length=max_length) + return SimpleNamespace(**{"input_ids": input_ids}) + + +def load_tokenizers(args: argparse.Namespace): + print("prepare tokenizers") + original_path = TOKENIZER_PATH + + tokenizer1: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + print(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer1 = CLIPTokenizer.from_pretrained(local_tokenizer_path) + + if tokenizer1 is None: + tokenizer1 = CLIPTokenizer.from_pretrained(original_path) + + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + print(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer1.save_pretrained(local_tokenizer_path) + + if hasattr(args, "max_token_length") and args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + + # tokenizer2 is from open_clip + # TODO caching + tokenizer2 = WrapperTokenizer() + + return [tokenizer1, tokenizer2] + + +def get_hidden_states( + args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None +): + input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + + # text_encoder1 + enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) + hidden_states1 = enc_out["hidden_states"][11] + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer + pool2 = enc_out["pooler_output"] + + if args.max_token_length is not None: + # bs*3, 77, 768 or 1024 + # encoder1: ... の三連を ... 
へ戻す + states_list = [hidden_states1[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer1.model_max_length): + states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで + states_list.append(hidden_states1[:, -1].unsqueeze(1)) # + hidden_states1 = torch.cat(states_list, dim=1) + + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids2[j, 1] == tokenizer2.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + hidden_states1 = hidden_states1.to(weight_dtype) + hidden_states2 = hidden_states2.to(weight_dtype) + + return hidden_states1, hidden_states2, pool2 + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_timestep_embedding(x, outdim): + assert len(x.shape) == 2 + b, dims = x.shape[0], x.shape[1] + x = torch.flatten(x) + emb = timestep_embedding(x, outdim) + emb = torch.reshape(emb, (b, dims * outdim)) + return emb + + +def get_size_embeddings(orig_size, crop_size, target_size, device): + emb1 = get_timestep_embedding(orig_size, 256) + emb2 = get_timestep_embedding(crop_size, 256) + emb3 = get_timestep_embedding(target_size, 256) + vector = torch.cat([emb1, emb2, emb3], dim=1).to(device) + return vector + + +def save_sd_model_on_train_end( + args: argparse.Namespace, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + text_encoder1, + text_encoder2, + unet, + vae, + text_projection, + logit_scale, + ckpt_info, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sdxl_model_util.save_stable_diffusion_checkpoint( + ckpt_file, + text_encoder1, + text_encoder2, + unet, + epoch_no, + global_step, + ckpt_info, + vae, + text_projection, + logit_scale, + save_dtype, + ) + + def diffusers_saver(out_dir): + raise NotImplementedError("diffusers_saver is not implemented") + + train_util.save_sd_model_on_train_end_common( + args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver + ) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + src_path, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: 
torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + text_encoder1, + text_encoder2, + unet, + vae, + text_projection, + logit_scale, + ckpt_info, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sdxl_model_util.save_stable_diffusion_checkpoint( + ckpt_file, + text_encoder1, + text_encoder2, + unet, + epoch_no, + global_step, + ckpt_info, + vae, + text_projection, + logit_scale, + save_dtype, + ) + + def diffusers_saver(out_dir): + raise NotImplementedError("diffusers_saver is not implemented") + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + save_stable_diffusion_format, + use_safetensors, + epoch, + num_train_epochs, + global_step, + sd_saver, + diffusers_saver, + ) + + +# TextEncoderの出力をキャッシュする +# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる +def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype): + print("caching text encoder outputs") + + tokenizer1, tokenizer2 = tokenizers + text_encoder1, text_encoder2 = text_encoders + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + if weight_dtype is not None: + text_encoder1.to(dtype=weight_dtype) + text_encoder2.to(dtype=weight_dtype) + + text_encoder1_cache = {} + text_encoder2_cache = {} + for batch in tqdm(data_loader): + input_ids1_batch = batch["input_ids"] + input_ids2_batch = batch["input_ids2"] + + # split batch to avoid OOM + # TODO specify batch size by args + for input_ids1, input_ids2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)): + # remove input_ids already in cache + input_ids1 = input_ids1.squeeze(0) + input_ids2 = input_ids2.squeeze(0) + input_ids1 = [i for i in input_ids1 if i not in text_encoder1_cache] + input_ids2 = [i for i in input_ids2 if i not in text_encoder2_cache] + assert len(input_ids1) == len(input_ids2) + if len(input_ids1) == 0: + continue + input_ids1 = torch.stack(input_ids1).to(accelerator.device) + input_ids2 = torch.stack(input_ids2).to(accelerator.device) + + with torch.no_grad(): + encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states( + args, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu") + encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu") + pool2 = pool2.to("cpu") + for input_id1, input_id2, hidden_states1, hidden_states2, p2 in zip( + input_ids1, input_ids2, encoder_hidden_states1, encoder_hidden_states2, pool2 + ): + text_encoder1_cache[tuple(input_id1.tolist())] = hidden_states1 + text_encoder2_cache[tuple(input_id2.tolist())] = (hidden_states2, p2) + return text_encoder1_cache, text_encoder2_cache + + +def add_sdxl_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + + +def verify_sdxl_training_args(args: argparse.Namespace): + assert ( + not args.v2 and not args.v_parameterization + ), "v2 or v_parameterization cannot be enabled in SDXL training / SDXL学習ではv2とv_parameterizationを有効にすることはできません" + if args.clip_skip is not None: + print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + if args.multires_noise_iterations: + print( + f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / 
SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + ) + else: + if args.noise_offset is None: + args.noise_offset = DEFAULT_NOISE_OFFSET + elif args.noise_offset != DEFAULT_NOISE_OFFSET: + print( + f"Waring: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + ) + print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + assert ( + not args.weighted_captions + ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" diff --git a/library/train_util.py b/library/train_util.py index 533bf0a9..e609705e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -798,6 +798,19 @@ class BaseDataset(torch.utils.data.Dataset): def is_latent_cacheable(self): return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + def is_text_encoder_output_cacheable(self): + return all( + [ + not ( + subset.caption_dropout_rate > 0 + or subset.shuffle_caption + or subset.token_warmup_step > 0 + or subset.caption_tag_dropout_rate > 0 + ) + for subset in self.subsets + ] + ) + def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path): expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 @@ -850,7 +863,7 @@ class BaseDataset(torch.utils.data.Dataset): continue cache_available = self.is_disk_cached_latents_is_expected( - info.bucket_reso, info.latents_npz, info.latents_npz_flipped if self.flip_aug else None + info.bucket_reso, info.latents_npz, info.latents_npz_flipped if subset.flip_aug else None ) if cache_available: # do not add to batch @@ -1719,6 +1732,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) + def is_text_encoder_output_cacheable(self) -> bool: + return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) + def set_current_epoch(self, epoch): for dataset in self.datasets: dataset.set_current_epoch(epoch) @@ -3284,11 +3300,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une return text_encoder, vae, unet, load_stable_diffusion_format +# TODO remove this function in the future def transform_if_model_is_DDP(text_encoder, unet, network=None): # Transform text_encoder, unet and network from DistributedDataParallel return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None) +def transform_models_if_DDP(models): + # Transform text_encoder, unet and network from DistributedDataParallel + return [model.module if type(model) == DDP else model for model in models if model is not None] + + def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False): # load models for each process for pi in range(accelerator.state.num_processes): @@ -3430,6 +3452,42 @@ def save_sd_model_on_epoch_end_or_stepwise( text_encoder, unet, vae, +): + def sd_saver(ckpt_file, epoch_no, global_step): + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae + ) + + def diffusers_saver(out_dir): + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + + save_sd_model_on_epoch_end_or_stepwise_common( + args, + 
on_epoch_end, + accelerator, + save_stable_diffusion_format, + use_safetensors, + epoch, + num_train_epochs, + global_step, + sd_saver, + diffusers_saver, + ) + + +def save_sd_model_on_epoch_end_or_stepwise_common( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_stable_diffusion_format: bool, + use_safetensors: bool, + epoch: int, + num_train_epochs: int, + global_step: int, + sd_saver, + diffusers_saver, ): if on_epoch_end: epoch_no = epoch + 1 @@ -3457,9 +3515,7 @@ def save_sd_model_on_epoch_end_or_stepwise( ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"\nsaving checkpoint: {ckpt_file}") - model_util.save_stable_diffusion_checkpoint( - args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae - ) + sd_saver(ckpt_file, epoch_no, global_step) if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) @@ -3483,9 +3539,8 @@ def save_sd_model_on_epoch_end_or_stepwise( out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) print(f"\nsaving model: {out_dir}") - model_util.save_diffusers_checkpoint( - args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors - ) + diffusers_saver(out_dir) + if args.huggingface_repo_id is not None: huggingface_util.upload(args, out_dir, "/" + model_name) @@ -3578,6 +3633,30 @@ def save_sd_model_on_train_end( text_encoder, unet, vae, +): + def sd_saver(ckpt_file, epoch_no, global_step): + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae + ) + + def diffusers_saver(out_dir): + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + + save_sd_model_on_train_end_common( + args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver + ) + + +def save_sd_model_on_train_end_common( + args: argparse.Namespace, + save_stable_diffusion_format: bool, + use_safetensors: bool, + epoch: int, + global_step: int, + sd_saver, + diffusers_saver, ): model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) @@ -3588,9 +3667,8 @@ def save_sd_model_on_train_end( ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint( - args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae - ) + sd_saver(ckpt_file, epoch, global_step) + if args.huggingface_repo_id is not None: huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) else: @@ -3598,9 +3676,8 @@ def save_sd_model_on_train_end( os.makedirs(out_dir, exist_ok=True) print(f"save trained model as Diffusers to {out_dir}") - model_util.save_diffusers_checkpoint( - args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors - ) + diffusers_saver(out_dir) + if args.huggingface_repo_id is not None: huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) diff --git a/networks/lora.py b/networks/lora.py index 10c5a07f..b6788b99 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -5,7 +5,9 @@ import math import os -from typing import List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel import numpy as np import torch import 
re @@ -400,7 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs): return down_lr_weight, mid_lr_weight, up_lr_weight -def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs): +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -719,33 +730,36 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh class LoRANetwork(torch.nn.Module): NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 - # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;) UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" + + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" + LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" def __init__( self, - text_encoder, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], unet, - multiplier=1.0, - lora_dim=4, - alpha=1, - dropout=None, - rank_dropout=None, - module_dropout=None, - conv_lora_dim=None, - conv_alpha=None, - block_dims=None, - block_alphas=None, - conv_block_dims=None, - conv_block_alphas=None, - modules_dim=None, - modules_alpha=None, - module_class=LoRAModule, - varbose=False, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + block_dims: Optional[List[int]] = None, + block_alphas: Optional[List[float]] = None, + conv_block_dims: Optional[List[int]] = None, + conv_block_alphas: Optional[List[float]] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, ) -> None: """ LoRA network: すごく引数が多いが、パターンは以下の通り @@ -783,8 +797,21 @@ class LoRANetwork(torch.nn.Module): print(f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances - def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]: - prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER + def create_modules( + is_unet: bool, + text_encoder_idx: Optional[int], # None, 1, 2 + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_UNET + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) + ) loras = [] skipped = [] for name, module in root_module.named_modules(): @@ -800,11 +827,14 @@ class LoRANetwork(torch.nn.Module): dim = None alpha = None + if modules_dim is not None: + # モジュール指定あり if lora_name in modules_dim: dim = modules_dim[lora_name] alpha = modules_alpha[lora_name] elif is_unet and block_dims is not None: + # U-Netでblock_dims指定あり block_idx = get_block_index(lora_name) if is_linear or is_conv2d_1x1: dim = block_dims[block_idx] @@ -813,6 +843,7 @@ class LoRANetwork(torch.nn.Module): dim = conv_block_dims[block_idx] alpha = conv_block_alphas[block_idx] else: + # 通常、すべて対象とする if is_linear or is_conv2d_1x1: dim = self.lora_dim alpha = self.alpha @@ -821,6 +852,7 @@ class LoRANetwork(torch.nn.Module): alpha = self.conv_alpha if dim is None or dim == 0: + # skipした情報を出力 if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): skipped.append(lora_name) continue @@ -838,7 +870,16 @@ class LoRANetwork(torch.nn.Module): loras.append(lora) return loras, skipped - self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + text_encoder_loras, skipped = create_modules(False, i + 1, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights @@ -846,7 +887,7 @@ class LoRANetwork(torch.nn.Module): if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 - self.unet_loras, skipped_un = create_modules(True, unet, target_modules) + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un @@ -961,6 +1002,7 @@ class LoRANetwork(torch.nn.Module): return lr_weight + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py new file mode 100644 index 00000000..0fc3f9c5 --- /dev/null +++ b/networks/sdxl_merge_lora.py @@ -0,0 +1,258 @@ +import math +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from library import sdxl_model_util +import library.model_util as 
model_util +import lora + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype): + text_encoder1.to(merge_dtype) + text_encoder1.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder1, text_encoder2, unet]): + if i <= 1: + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1 + else: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2 + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + else: + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in tqdm(lora_sd.keys()): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + print(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # print(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + # print(module_name, down_weight.size(), up_weight.size()) + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + +def merge_lora_models(models, ratios, merge_dtype): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + # get alpha and dim + alphas = {} # alpha for current model + 
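# For reference (not part of the patch): the weight update applied in merge_to_sd_model above is
# the standard LoRA merge, W <- W + ratio * (up @ down) * (alpha / dim). A minimal sketch for a
# Linear layer, with hypothetical tensor arguments:
import torch

def merge_lora_into_linear(weight, up_weight, down_weight, alpha, ratio):
    # weight: (out, in), up_weight: (out, rank), down_weight: (rank, in)
    rank = down_weight.shape[0]
    scale = alpha / rank
    return weight + ratio * (up_weight @ down_weight) * scale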
dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + print(f"merging...") + for key in tqdm(lora_sd.keys()): + if "alpha" in key: + continue + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + + print("merged model") + print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + return merged_sd + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + if args.sd_model is not None: + print(f"loading SD model: {args.sd_model}") + + ( + text_model1, + text_model2, + vae, + unet, + text_projection, + logit_scale, + ckpt_info, + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu") + + merge_to_sd_model(text_model2, text_model2, unet, args.models, args.ratios, merge_dtype) + + print(f"saving SD model to: {args.save_to}") + sdxl_model_util.save_stable_diffusion_checkpoint( + args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, text_projection, logit_scale, save_dtype + ) + else: + state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--sd_model", + type=str, + default=None, + help="Stable Diffusion model to load: ckpt 
or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", + ) + parser.add_argument( + "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + ) + parser.add_argument( + "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index f8e7d687..2f3670df 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -11,10 +11,13 @@ import numpy as np import torch from tqdm import tqdm from transformers import CLIPTokenizer -from library import sdxl_model_util from diffusers import EulerDiscreteScheduler from PIL import Image import open_clip +from safetensors.torch import load_file + +from library import model_util, sdxl_model_util +import networks.lora as lora # scheduler: このあたりの設定はSD1/2と同じでいいらしい # scheduler: The settings around here seem to be the same as SD1/2 @@ -85,6 +88,13 @@ if __name__ == "__main__": parser.add_argument("--prompt", type=str, default="A photo of a cat") parser.add_argument("--negative_prompt", type=str, default="") parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)", + ) args = parser.parse_args() # HuggingFaceのmodel id @@ -97,7 +107,7 @@ if __name__ == "__main__": # 本体RAMが少ない場合はGPUにロードするといいかも # If the main RAM is small, it may be better to load it on the GPU text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - "sdxl_base_v0-9", args.ckpt_path, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu" ) # Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている @@ -134,6 +144,19 @@ if __name__ == "__main__": unet.set_use_memory_efficient_attention(True, False) + # LoRA + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + lora_model, weights_sd = lora.create_network_from_weights( + multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True + ) + lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE) + # prepare embedding with torch.no_grad(): # vector @@ -248,7 +271,7 @@ if __name__ == "__main__": latents = scheduler.step(noise_pred, t, latents).prev_sample # latents = 1 / 0.18215 * latents - latents = 1 / 0.13025 * latents + latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents latents = latents.to(torch.float32) image = vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) diff --git a/sdxl_train.py b/sdxl_train.py new file mode 100644 index 00000000..2683038b --- /dev/null +++ b/sdxl_train.py @@ -0,0 +1,605 @@ +# training with captions + +import argparse +import gc +import math +import os +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from library import sdxl_model_util + +import library.train_util as train_util +import library.config_util as 
config_util +import library.sdxl_train_util as sdxl_train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, +) +from library.sdxl_original_unet import SdxlUNet2DConditionModel + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + + assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + assert ( + not args.train_text_encoder or not args.cache_text_encoder_outputs + ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません" + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2]) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + text_projection, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) + text_projection = text_projection.to(accelerator.device, dtype=weight_dtype) + logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります" + + # Diffusers版のxformers使用フラグを設定する関数 + def set_diffusers_xformers_flag(model, valid): + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + fn_recursive_set_mem_eff(model) + + # モデルに xformers とか memory efficient attention を組み込む + if args.diffusers_xformers: + # もうU-Netを独自にしたので動かないけどVAEのxformersは動くはず + accelerator.print("Use xformers by Diffusers") + # set_diffusers_xformers_flag(unet, True) + set_diffusers_xformers_flag(vae, True) + else: + # Windows版のxformersはfloatで学習できなかったりxformersを使わない設定も可能にしておく必要がある + accelerator.print("Disable Diffusers' xformers") + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + vae.set_use_memory_efficient_attention_xformers(args.xformers) + + # 学習を準備する + if cache_latents: + # vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=torch.float32) # VAE in float to avoid NaN + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + # 学習を準備する:モデルを適切な状態にする + training_models = [] + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) + + if 
args.train_text_encoder: + # TODO each option for two text encoders? + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder1.gradient_checkpointing_enable() + text_encoder2.gradient_checkpointing_enable() + training_models.append(text_encoder1) + training_models.append(text_encoder2) + else: + text_encoder1.requires_grad_(False) + text_encoder2.requires_grad_(False) + text_encoder1.eval() + text_encoder2.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + for m in training_models: + m.requires_grad_(True) + params = [] + for m in training_models: + params.extend(m.parameters()) + params_to_optimize = params + + # calculate number of trainable parameters + n_params = 0 + for p in params: + n_params += p.numel() + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler + ) + + # transform DDP after prepare + text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet]) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + (unet,) = train_util.transform_models_if_DDP([unet]) + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) + text_encoder1.eval() + text_encoder2.eval() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( + args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None + ) + accelerator.wait_for_everyone() + text_encoder1.to("cpu") + text_encoder2.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: + text_encoder1_cache = None + text_encoder2_cache 
= None + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + + if accelerator.is_main_process: + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name) + + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + # with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + if True: + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # latentに変換 + # latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = vae.encode(batch["images"].to(torch.float32)).latent_dist.sample().to(weight_dtype) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + b_size = latents.shape[0] + + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + if not args.cache_text_encoder_outputs: + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + # TODO support weighted captions + # if args.weighted_captions: + # encoder_hidden_states = get_weighted_text_embeddings( + # tokenizer, + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + # else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states( + 
args, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + pool2 = pool2 @ text_projection.to(pool2.dtype) + else: + encoder_hidden_states1 = [] + encoder_hidden_states2 = [] + pool2 = [] + for input_id1, input_id2 in zip(input_ids1, input_ids2): + input_id1 = input_id1.squeeze(0) + input_id2 = input_id2.squeeze(0) + encoder_hidden_states1.append(text_encoder1_cache[tuple(input_id1.tolist())]) + hidden_states2, p2 = text_encoder2_cache[tuple(input_id2.tolist())] + encoder_hidden_states2.append(hidden_states2) + pool2.append(p2) + encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype) + pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype) + + pool2 = pool2 @ text_projection.to(pool2.dtype) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + + target = noise + + if args.min_snr_gamma: + # do not mean over batch dimension for snr weight or scale v-pred loss + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss = loss.mean() # mean over batch dimension + else: + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # None, + # global_step, + # accelerator.device, + # vae, + # tokenizer1, + # 
tokenizer2, + # text_encoder1, + # text_encoder2, + # unet, + # ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder1), + accelerator.unwrap_model(text_encoder2), + accelerator.unwrap_model(unet), + vae, + text_projection, + logit_scale, + ckpt_info, + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy" + ): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + # TODO moving averageにする + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder1), + accelerator.unwrap_model(text_encoder2), + accelerator.unwrap_model(unet), + vae, + text_projection, + logit_scale, + ckpt_info, + ) + + # train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + is_main_process = accelerator.is_main_process + # if is_main_process: + unet = accelerator.unwrap_model(unet) + text_encoder1 = accelerator.unwrap_model(text_encoder1) + text_encoder2 = accelerator.unwrap_model(text_encoder2) + + accelerator.end_training() + + if args.save_state: # and is_main_process: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + sdxl_train_util.save_sd_model_on_train_end( + args, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + global_step, + text_encoder1, + text_encoder2, + unet, + vae, + text_projection, + logit_scale, + ckpt_info, + ) + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + 
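# For reference (not part of the patch): when --min_snr_gamma is set, the training loop above
# rescales the per-sample MSE by min(SNR(t), gamma) / SNR(t) before averaging (Min-SNR weighting).
# The actual implementation lives in library.custom_train_functions.apply_snr_weight; a rough
# sketch of the idea for epsilon prediction, assuming `snr` holds the scheduler's per-timestep SNR:
import torch

def min_snr_weight(snr_t: torch.Tensor, gamma: float) -> torch.Tensor:
    return torch.minimum(snr_t, torch.full_like(snr_t, gamma)) / snr_t

# usage: loss = (per_sample_mse * min_snr_weight(snr[timesteps], args.min_snr_gamma)).mean()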
custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/sdxl_train_network.py b/sdxl_train_network.py new file mode 100644 index 00000000..1bd4cb74 --- /dev/null +++ b/sdxl_train_network.py @@ -0,0 +1,172 @@ +import argparse +import torch +from library import sdxl_model_util, sdxl_train_util, train_util +import train_network + + +class SdxlNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args) + sdxl_train_util.verify_sdxl_training_args(args) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + assert ( + args.network_train_unet_only or not args.cache_text_encoder_outputs + ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + + def load_target_model(self, args, weight_dtype, accelerator): + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + text_projection, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype) + + self.load_stable_diffusion_format = load_stable_diffusion_format + self.text_projection = text_projection.to(accelerator.device, dtype=weight_dtype) + self.logit_scale = logit_scale + self.ckpt_info = ckpt_info + + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet + + def load_tokenizer(self, args): + tokenizer = sdxl_train_util.load_tokenizers(args) + return tokenizer + + def is_text_encoder_outputs_cached(self, args): + return args.cache_text_encoder_outputs + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + print("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( + args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype + ) + accelerator.wait_for_everyone() + text_encoders[0].to("cpu") + text_encoders[1].to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + self.text_encoder1_cache = text_encoder1_cache + self.text_encoder2_cache = text_encoder2_cache + + if not args.lowram: + print("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + 
self.text_encoder1_cache = None + self.text_encoder2_cache = None + + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device) + text_encoders[1].to(accelerator.device) + + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + if not args.cache_text_encoder_outputs: + with torch.enable_grad(): + # Get the text embedding for conditioning + # TODO support weighted captions + # if args.weighted_captions: + # encoder_hidden_states = get_weighted_text_embeddings( + # tokenizer, + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + # else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states( + args, + input_ids1, + input_ids2, + tokenizers[0], + tokenizers[1], + text_encoders[0], + text_encoders[1], + None if not args.full_fp16 else weight_dtype, + ) + pool2 = pool2 @ self.text_projection.to(pool2.dtype) + else: + encoder_hidden_states1 = [] + encoder_hidden_states2 = [] + pool2 = [] + for input_id1, input_id2 in zip(input_ids1, input_ids2): + input_id1 = input_id1.squeeze(0) + input_id2 = input_id2.squeeze(0) + encoder_hidden_states1.append(self.text_encoder1_cache[tuple(input_id1.tolist())]) + hidden_states2, p2 = self.text_encoder2_cache[tuple(input_id2.tolist())] + encoder_hidden_states2.append(hidden_states2) + pool2.append(p2) + encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype) + pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype) + + pool2 = pool2 @ self.text_projection.to(weight_dtype) + + return encoder_hidden_states1, encoder_hidden_states2, pool2 + + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): + print("sample_images is not implemented") + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + sdxl_train_util.add_sdxl_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + trainer = SdxlNetworkTrainer() + trainer.train(args) diff --git a/train_network.py b/train_network.py index 7e930e8a..e2920db4 100644 --- a/train_network.py +++ b/train_network.py @@ -3,6 +3,7 @@ import argparse import 
gc import math import os +import sys import random import time import json @@ -12,6 +13,7 @@ from tqdm import tqdm import torch from accelerate.utils import set_seed from diffusers import DDPMScheduler +from library import model_util import library.train_util as train_util from library.train_util import ( @@ -34,753 +36,855 @@ from library.custom_train_functions import ( ) -# TODO 他のスクリプトと共通化する -def generate_step_logs( - args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None -): - logs = {"loss/current": current_loss, "loss/average": avr_loss} +class NetworkTrainer: + def __init__(self): + self.vae_scale_factor = 0.18215 - if keys_scaled is not None: - logs["max_norm/keys_scaled"] = keys_scaled - logs["max_norm/average_key_norm"] = mean_norm - logs["max_norm/max_key_norm"] = maximum_norm + # TODO 他のスクリプトと共通化する + def generate_step_logs( + self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None + ): + logs = {"loss/current": current_loss, "loss/average": avr_loss} - lrs = lr_scheduler.get_last_lr() + if keys_scaled is not None: + logs["max_norm/keys_scaled"] = keys_scaled + logs["max_norm/average_key_norm"] = mean_norm + logs["max_norm/max_key_norm"] = maximum_norm - if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) - if args.network_train_unet_only: - logs["lr/unet"] = float(lrs[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lrs[0]) - else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder + lrs = lr_scheduler.get_last_lr() - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value of unet. - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - else: - idx = 0 - if not args.network_train_unet_only: - logs["lr/textencoder"] = float(lrs[0]) - idx = 1 + if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) + if args.network_train_unet_only: + logs["lr/unet"] = float(lrs[0]) + elif args.network_train_text_encoder_only: + logs["lr/textencoder"] = float(lrs[0]) + else: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - for i in range(idx, len(lrs)): - logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{i}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value of unet. 
+ logs["lr/d*lr"] = ( + lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] ) + else: + idx = 0 + if not args.network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 - return logs - - -def train(args): - session_id = random.randint(0, 2**32) - training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) - - cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None - use_user_config = args.dataset_config is not None - - if args.seed is None: - args.seed = random.randint(0, 2**32) - set_seed(args.seed) - - tokenizer = train_util.load_tokenizer(args) - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) - if use_user_config: - print(f"Loading dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + for i in range(idx, len(lrs)): + logs[f"lr/group{i}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) - ) - else: - if use_dreambooth_method: - print("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - print("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - # use arbitrary dataset class - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - print( - "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" - ) - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - print("preparing accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 差分追加学習のためにモデルを読み込む - import sys - - sys.path.append(os.path.dirname(__file__)) - accelerator.print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) - - if args.base_weights is not None: - # base_weights が指定されている場合は、指定された重みを読み込みマージする - for i, weight_path in enumerate(args.base_weights): - if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: - multiplier = 1.0 - else: - multiplier = args.base_weights_multiplier[i] - - accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") - - module, weights_sd = network_module.create_network_from_weights( - multiplier, weight_path, vae, text_encoder, unet, for_inference=True - ) - module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") - - accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - accelerator.wait_for_everyone() - - # prepare network - net_kwargs = {} - if args.network_args is not None: - for net_arg in args.network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - # if a new network is added in future, add if ~ then blocks for each network (;'∀') - if args.dim_from_weights: - network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) - else: - # LyCORIS will work with this... 
- network = network_module.create_network( - 1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, neuron_dropout=args.network_dropout, **net_kwargs - ) - if network is None: - return - - if hasattr(network, "prepare_network"): - network.prepare_network(args) - if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): - print( - "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" - ) - args.scale_weight_norms = False - - train_unet = not args.network_train_text_encoder_only - train_text_encoder = not args.network_train_unet_only - network.apply_to(text_encoder, unet, train_text_encoder, train_unet) - - if args.network_weights is not None: - info = network.load_weights(args.network_weights) - accelerator.print(f"load network weights from {args.network_weights}: {info}") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - network.enable_gradient_checkpointing() # may have no effect - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - - # 後方互換性を確保するよ - try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) - except TypeError: - accelerator.print( - "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) - - optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collater, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps - ) - accelerator.print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert ( - args.mixed_precision == "fp16" - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - accelerator.print("enable full fp16 training.") - network.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if train_unet and train_text_encoder: - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler - ) - elif train_unet: - unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, network, optimizer, train_dataloader, lr_scheduler - ) - elif train_text_encoder: - text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, network, optimizer, train_dataloader, lr_scheduler - ) - else: - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) - - # transform DDP after prepare (train_network here only) - text_encoder, unet, network = train_util.transform_if_model_is_DDP(text_encoder, unet, network) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - text_encoder.train() - - # set top parameter requires_grad = True for gradient checkpointing works - text_encoder.text_model.embeddings.requires_grad_(True) - else: - unet.eval() - text_encoder.eval() - - network.prepare_grad_etc(text_encoder, unet) - - if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - accelerator.print("running training / 学習開始") - accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 
総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - # TODO refactor metadata creation and move to util - metadata = { - "ss_session_id": session_id, # random integer indicating which group of epochs the model came from - "ss_training_started_at": training_started_at, # unix timestamp - "ss_output_name": args.output_name, - "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, - "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset_group.num_train_images, - "ss_num_reg_images": train_dataset_group.num_reg_images, - "ss_num_batches_per_epoch": len(train_dataloader), - "ss_num_epochs": num_train_epochs, - "ss_gradient_checkpointing": args.gradient_checkpointing, - "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, - "ss_max_train_steps": args.max_train_steps, - "ss_lr_warmup_steps": args.lr_warmup_steps, - "ss_lr_scheduler": args.lr_scheduler, - "ss_network_module": args.network_module, - "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim - "ss_network_alpha": args.network_alpha, # some networks may not have alpha - "ss_network_dropout": args.network_dropout, # some networks may not have dropout - "ss_mixed_precision": args.mixed_precision, - "ss_full_fp16": bool(args.full_fp16), - "ss_v2": bool(args.v2), - "ss_clip_skip": args.clip_skip, - "ss_max_token_length": args.max_token_length, - "ss_cache_latents": bool(args.cache_latents), - "ss_seed": args.seed, - "ss_lowram": args.lowram, - "ss_noise_offset": args.noise_offset, - "ss_multires_noise_iterations": args.multires_noise_iterations, - "ss_multires_noise_discount": args.multires_noise_discount, - "ss_adaptive_noise_scale": args.adaptive_noise_scale, - "ss_training_comment": args.training_comment, # will not be updated after training - "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), - "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), - "ss_max_grad_norm": args.max_grad_norm, - "ss_caption_dropout_rate": args.caption_dropout_rate, - "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, - "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, - "ss_face_crop_aug_range": args.face_crop_aug_range, - "ss_prior_loss_weight": args.prior_loss_weight, - "ss_min_snr_gamma": args.min_snr_gamma, - "ss_scale_weight_norms": args.scale_weight_norms, - } - - if use_user_config: - # save metadata of multiple datasets - # NOTE: pack "ss_datasets" value as json one time - # or should also pack nested collections as json? 
- datasets_metadata = [] - tag_frequency = {} # merge tag frequency for metadata editor - dataset_dirs_info = {} # merge subset dirs for metadata editor - - for dataset in train_dataset_group.datasets: - is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) - dataset_metadata = { - "is_dreambooth": is_dreambooth_dataset, - "batch_size_per_device": dataset.batch_size, - "num_train_images": dataset.num_train_images, # includes repeating - "num_reg_images": dataset.num_reg_images, - "resolution": (dataset.width, dataset.height), - "enable_bucket": bool(dataset.enable_bucket), - "min_bucket_reso": dataset.min_bucket_reso, - "max_bucket_reso": dataset.max_bucket_reso, - "tag_frequency": dataset.tag_frequency, - "bucket_info": dataset.bucket_info, - } - - subsets_metadata = [] - for subset in dataset.subsets: - subset_metadata = { - "img_count": subset.img_count, - "num_repeats": subset.num_repeats, - "color_aug": bool(subset.color_aug), - "flip_aug": bool(subset.flip_aug), - "random_crop": bool(subset.random_crop), - "shuffle_caption": bool(subset.shuffle_caption), - "keep_tokens": subset.keep_tokens, - } - - image_dir_or_metadata_file = None - if subset.image_dir: - image_dir = os.path.basename(subset.image_dir) - subset_metadata["image_dir"] = image_dir - image_dir_or_metadata_file = image_dir - - if is_dreambooth_dataset: - subset_metadata["class_tokens"] = subset.class_tokens - subset_metadata["is_reg"] = subset.is_reg - if subset.is_reg: - image_dir_or_metadata_file = None # not merging reg dataset - else: - metadata_file = os.path.basename(subset.metadata_file) - subset_metadata["metadata_file"] = metadata_file - image_dir_or_metadata_file = metadata_file # may overwrite - - subsets_metadata.append(subset_metadata) - - # merge dataset dir: not reg subset only - # TODO update additional-network extension to show detailed dataset config from metadata - if image_dir_or_metadata_file is not None: - # datasets may have a certain dir multiple times - v = image_dir_or_metadata_file - i = 2 - while v in dataset_dirs_info: - v = image_dir_or_metadata_file + f" ({i})" - i += 1 - image_dir_or_metadata_file = v - - dataset_dirs_info[image_dir_or_metadata_file] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} - - dataset_metadata["subsets"] = subsets_metadata - datasets_metadata.append(dataset_metadata) - - # merge tag frequency: - for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): - # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える - # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない - # なので、ここで複数datasetの回数を合算してもあまり意味はない - if ds_dir_name in tag_frequency: - continue - tag_frequency[ds_dir_name] = ds_freq_for_dir - - metadata["ss_datasets"] = json.dumps(datasets_metadata) - metadata["ss_tag_frequency"] = json.dumps(tag_frequency) - metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) - else: - # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir - assert ( - len(train_dataset_group.datasets) == 1 - ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. 
/ データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" - - dataset = train_dataset_group.datasets[0] - - dataset_dirs_info = {} - reg_dataset_dirs_info = {} - if use_dreambooth_method: - for subset in dataset.subsets: - info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info - info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} - else: - for subset in dataset.subsets: - dataset_dirs_info[os.path.basename(subset.metadata_file)] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count, - } - - metadata.update( - { - "ss_batch_size_per_device": args.train_batch_size, - "ss_total_batch_size": total_batch_size, - "ss_resolution": args.resolution, - "ss_color_aug": bool(args.color_aug), - "ss_flip_aug": bool(args.flip_aug), - "ss_random_crop": bool(args.random_crop), - "ss_shuffle_caption": bool(args.shuffle_caption), - "ss_enable_bucket": bool(dataset.enable_bucket), - "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), - "ss_min_bucket_reso": dataset.min_bucket_reso, - "ss_max_bucket_reso": dataset.max_bucket_reso, - "ss_keep_tokens": args.keep_tokens, - "ss_dataset_dirs": json.dumps(dataset_dirs_info), - "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), - "ss_tag_frequency": json.dumps(dataset.tag_frequency), - "ss_bucket_info": json.dumps(dataset.bucket_info), - } - ) - - # add extra args - if args.network_args: - metadata["ss_network_args"] = json.dumps(net_kwargs) - - # model name and hash - if args.pretrained_model_name_or_path is not None: - sd_model_name = args.pretrained_model_name_or_path - if os.path.exists(sd_model_name): - metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) - metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) - sd_model_name = os.path.basename(sd_model_name) - metadata["ss_sd_model_name"] = sd_model_name - - if args.vae is not None: - vae_name = args.vae - if os.path.exists(vae_name): - metadata["ss_vae_hash"] = train_util.model_hash(vae_name) - metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) - vae_name = os.path.basename(vae_name) - metadata["ss_vae_name"] = vae_name - - metadata = {k: str(v) for k, v in metadata.items()} - - # make minimum metadata for filtering - minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"] - minimum_metadata = {} - for key in minimum_keys: - if key in metadata: - minimum_metadata[key] = metadata[key] - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False - ) - prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) - - if accelerator.is_main_process: - accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name) - - loss_list = [] - loss_total = 0.0 - del train_dataset_group - - # callback for step start - if hasattr(network, "on_step_start"): - on_step_start = network.on_step_start - else: - on_step_start = lambda *args, **kwargs: None - - # function for saving/removing - def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - 
metadata["ss_training_finished_at"] = str(time.time()) - metadata["ss_steps"] = str(steps) - metadata["ss_epoch"] = str(epoch_no) - - unwrapped_nw.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # training loop - for epoch in range(num_train_epochs): - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - metadata["ss_epoch"] = str(epoch + 1) - - network.on_epoch_start(text_encoder, unet) - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(network): - on_step_start(text_encoder, unet) - - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - with torch.set_grad_enabled(train_text_encoder): - # Get the text embedding for conditioning - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, + return logs + + def assert_extra_args(self, args): + pass + + def load_target_model(self, args, weight_dtype, accelerator): + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + + def load_tokenizer(self, args): + tokenizer = train_util.load_tokenizer(args) + return tokenizer + + def is_text_encoder_outputs_cached(self, args): + return False + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator, uner, vae, tokenizers, text_encoders, data_loader, weight_dtype + ): + for t_enc in text_encoders: + t_enc.to(accelerator.device) + + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) + return encoder_hidden_states + + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + noise_pred = unet(noisy_latents, timesteps, text_conds).sample + return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + + def train(self, args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + # tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため + 
tokenizer = self.load_tokenizer(args) + tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer] + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if use_user_config: + print(f"Loading dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) - - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + ) + else: + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } else: - target = noise + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. 
Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = network.get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + self.assert_extra_args(args, train_dataset_group) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + # acceleratorを準備する + print("preparing accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process - if args.scale_weight_norms: - keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization( - args.scale_weight_norms, accelerator.device + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) + + # text_encoder is List[CLIPTextModel] or CLIPTextModel + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 差分追加学習のためにモデルを読み込む + sys.path.append(os.path.dirname(__file__)) + accelerator.print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + if args.base_weights is not None: + # base_weights が指定されている場合は、指定された重みを読み込みマージする + for i, weight_path in enumerate(args.base_weights): + if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + multiplier = 1.0 + else: + multiplier = args.base_weights_multiplier[i] + + accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") + + module, weights_sd = network_module.create_network_from_weights( + multiplier, weight_path, vae, text_encoder, unet, for_inference=True ) - max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} - else: - keys_scaled, mean_norm, maximum_norm = None, None, None + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + # prepare network + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + # if a new network is added in future, add if ~ then blocks for each 
network (;'∀') + if args.dim_from_weights: + network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) + else: + # LyCORIS will work with this... + network = network_module.create_network( + 1.0, + args.network_dim, + args.network_alpha, + vae, + text_encoder, + unet, + neuron_dropout=args.network_dropout, + **net_kwargs, + ) + if network is None: + return + + if hasattr(network, "prepare_network"): + network.prepare_network(args) + if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): + print( + "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" + ) + args.scale_weight_norms = False + + train_unet = not args.network_train_text_encoder_only + train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args) + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + accelerator.print(f"load network weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + for t_enc in text_encoders: + t_enc.gradient_checkpointing_enable() + del t_enc + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + # 後方互換性を確保するよ + try: + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + except TypeError: + accelerator.print( + "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" + ) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + network.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + # TODO めちゃくちゃ冗長なのでコードを整理する + if train_unet and train_text_encoder: + if len(text_encoders) > 1: + unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - if epoch == 0: - loss_list.append(current_loss) + text_encoder = text_encoders = [t_enc1, t_enc2] + del t_enc1, t_enc2 else: - loss_total -= loss_list[step] - loss_list[step] = current_loss - loss_total += current_loss - avr_loss = loss_total / len(loss_list) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + text_encoders = [text_encoder] + elif train_unet: + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_text_encoder: + if len(text_encoders) > 1: + t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler + ) + text_encoder = text_encoders = [t_enc1, t_enc2] + del t_enc1, t_enc2 + else: + text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + text_encoders = [text_encoder] + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler + ) - if args.scale_weight_norms: - progress_bar.set_postfix(**{**max_mean_logs, **logs}) + # transform DDP after prepare (train_network here only) + text_encoders = train_util.transform_models_if_DDP(text_encoders) + unet, network = train_util.transform_models_if_DDP([unet, network]) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + for t_enc in text_encoders: + t_enc.requires_grad_(False) + + if args.gradient_checkpointing: + # according to TI example in Diffusers, train is required + unet.train() + 
for t_enc in text_encoders: + t_enc.train() + + # set top parameter requires_grad = True for gradient checkpointing works + t_enc.text_model.embeddings.requires_grad_(True) + else: + unet.eval() + for t_enc in text_encoders: + t_enc.eval() + + del t_enc + + network.prepare_grad_etc(text_encoder, unet) + + if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + self.cache_text_encoder_outputs_if_needed( + args, accelerator, unet, vae, tokenizers, text_encoders, train_dataloader, weight_dtype + ) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + # TODO refactor metadata creation and move to util + metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_gradient_checkpointing": args.gradient_checkpointing, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not have alpha + "ss_network_dropout": args.network_dropout, # some networks may not have dropout + "ss_mixed_precision": args.mixed_precision, + 
"ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_base_model_version": model_version, + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_cache_latents": bool(args.cache_latents), + "ss_seed": args.seed, + "ss_lowram": args.lowram, + "ss_noise_offset": args.noise_offset, + "ss_multires_noise_iterations": args.multires_noise_iterations, + "ss_multires_noise_discount": args.multires_noise_discount, + "ss_adaptive_noise_scale": args.adaptive_noise_scale, + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_max_grad_norm": args.max_grad_norm, + "ss_caption_dropout_rate": args.caption_dropout_rate, + "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, + "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, + "ss_face_crop_aug_range": args.face_crop_aug_range, + "ss_prior_loss_weight": args.prior_loss_weight, + "ss_min_snr_gamma": args.min_snr_gamma, + "ss_scale_weight_norms": args.scale_weight_norms, + } + + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? + datasets_metadata = [] + tag_frequency = {} # merge tag frequency for metadata editor + dataset_dirs_info = {} # merge subset dirs for metadata editor + + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } + + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + } + + image_dir_or_metadata_file = None + if subset.image_dir: + image_dir = os.path.basename(subset.image_dir) + subset_metadata["image_dir"] = image_dir + image_dir_or_metadata_file = image_dir + + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + if subset.is_reg: + image_dir_or_metadata_file = None # not merging reg dataset + else: + metadata_file = os.path.basename(subset.metadata_file) + subset_metadata["metadata_file"] = metadata_file + image_dir_or_metadata_file = metadata_file # may overwrite + + subsets_metadata.append(subset_metadata) + + # merge dataset dir: not reg subset only + # TODO update additional-network extension to show detailed dataset config from metadata + if image_dir_or_metadata_file is not None: + # datasets may have a certain dir multiple times + v = image_dir_or_metadata_file + i = 2 + while v in dataset_dirs_info: + v = image_dir_or_metadata_file + f" ({i})" + i += 1 + image_dir_or_metadata_file = v + + 
dataset_dirs_info[image_dir_or_metadata_file] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + dataset_metadata["subsets"] = subsets_metadata + datasets_metadata.append(dataset_metadata) + + # merge tag frequency: + for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): + # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える + # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない + # なので、ここで複数datasetの回数を合算してもあまり意味はない + if ds_dir_name in tag_frequency: + continue + tag_frequency[ds_dir_name] = ds_freq_for_dir + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + metadata["ss_tag_frequency"] = json.dumps(tag_frequency) + metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) + else: + # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir + assert ( + len(train_dataset_group.datasets) == 1 + ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" + + dataset = train_dataset_group.datasets[0] + + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + metadata.update( + { + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + } + ) + + # add extra args + if args.network_args: + metadata["ss_network_args"] = json.dumps(net_kwargs) + + # model name and hash + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + + metadata = {k: str(v) for k, v in metadata.items()} + + # make minimum metadata for filtering + minimum_keys = [ + "ss_v2", + "ss_base_model_version", + "ss_network_module", + "ss_network_dim", + "ss_network_alpha", + "ss_network_args", + ] + minimum_metadata = {} + for key in minimum_keys: + if key in metadata: + minimum_metadata[key] = 
metadata[key] + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + + if accelerator.is_main_process: + accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name) + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + + # callback for step start + if hasattr(network, "on_step_start"): + on_step_start = network.on_step_start + else: + on_step_start = lambda *args, **kwargs: None + + # function for saving/removing + def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + metadata["ss_training_finished_at"] = str(time.time()) + metadata["ss_steps"] = str(steps) + metadata["ss_epoch"] = str(epoch_no) + + unwrapped_nw.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + metadata["ss_epoch"] = str(epoch + 1) + + network.on_epoch_start(text_encoder, unet) + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(network): + on_step_start(text_encoder, unet) + + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder): + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + elif args.multires_noise_iterations: + noise = pyramid_noise_like( + noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount + ) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = 
noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    # Predict the noise residual
+                    with accelerator.autocast():
+                        noise_pred = self.call_unet(
+                            args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
+                        )
+
+                    if args.v_parameterization:
+                        # v-parameterization training
+                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    else:
+                        target = noise
+
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean([1, 2, 3])
+
+                    loss_weights = batch["loss_weights"]  # weight for each sample
+                    loss = loss * loss_weights
+
+                    if args.min_snr_gamma:
+                        loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    if args.scale_v_pred_loss_like_noise_pred:
+                        loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+
+                    loss = loss.mean()  # already the mean over the batch, so no need to divide by batch_size
+
+                    accelerator.backward(loss)
+                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                        params_to_clip = network.get_trainable_params()
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+
+                if args.scale_weight_norms:
+                    keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
+                        args.scale_weight_norms, accelerator.device
+                    )
+                    max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
+                else:
+                    keys_scaled, mean_norm, maximum_norm = None, None, None
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    progress_bar.update(1)
+                    global_step += 1
+
+                    self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
+                    # save the model every specified number of steps
+                    if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
+                        accelerator.wait_for_everyone()
+                        if accelerator.is_main_process:
+                            ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
+                            save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
+
+                            if args.save_state:
+                                train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
+
+                            remove_step_no = train_util.get_remove_step_no(args, global_step)
+                            if remove_step_no is not None:
+                                remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
+                                remove_model(remove_ckpt_name)
+
+                current_loss = loss.detach().item()
+                if epoch == 0:
+                    loss_list.append(current_loss)
+                else:
+                    loss_total -= loss_list[step]
+                    loss_list[step] = current_loss
+                loss_total += current_loss
+                avr_loss = loss_total / len(loss_list)
+                logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+                progress_bar.set_postfix(**logs)
+
+                if args.scale_weight_norms:
+                    progress_bar.set_postfix(**{**max_mean_logs, **logs})
+
+                if args.logging_dir is not None:
+                    logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
+                    accelerator.log(logs, step=global_step)
+
+                if global_step >= args.max_train_steps:
+                    break
 
             if args.logging_dir is not None:
-                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
-                accelerator.log(logs, step=global_step)
+                logs = {"loss/epoch": loss_total / len(loss_list)}
+                accelerator.log(logs, step=epoch + 1)
-            if global_step >= args.max_train_steps:
-                break
+            accelerator.wait_for_everyone()
-        if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
-            accelerator.log(logs, step=epoch + 1)
+            # save the model every specified number of epochs
+            if args.save_every_n_epochs is not None:
+                saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
+                if is_main_process and saving:
+                    ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
+                    save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
-        accelerator.wait_for_everyone()
+                    remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
+                    if remove_epoch_no is not None:
+                        remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
+                        remove_model(remove_ckpt_name)
-        # save the model every specified number of epochs
-        if args.save_every_n_epochs is not None:
-            saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
-            if is_main_process and saving:
-                ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
-                save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
+                if args.save_state:
+                    train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
-            remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
-            if remove_epoch_no is not None:
-                remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
-                remove_model(remove_ckpt_name)
+            self.sample_images(
+                accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+            )
-            if args.save_state:
-                train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
+            # end of epoch
-        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        # metadata["ss_epoch"] = str(num_train_epochs)
+        metadata["ss_training_finished_at"] = str(time.time())
-        # end of epoch
+        if is_main_process:
+            network = accelerator.unwrap_model(network)
-    # metadata["ss_epoch"] = str(num_train_epochs)
-    metadata["ss_training_finished_at"] = str(time.time())
+        accelerator.end_training()
-    if is_main_process:
-        network = accelerator.unwrap_model(network)
+        if is_main_process and args.save_state:
+            train_util.save_state_on_train_end(args, accelerator)
-    accelerator.end_training()
+        if is_main_process:
+            ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
+            save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
-    if is_main_process and args.save_state:
-        train_util.save_state_on_train_end(args, accelerator)
-
-    del accelerator  # delete this because memory is needed afterwards
-
-    if is_main_process:
-        ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
-        save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
-
-        print("model saved.")
+            print("model saved.")
 
 
 def setup_parser() -> argparse.ArgumentParser:
@@ -866,4 +970,5 @@ if __name__ == "__main__":
     args = parser.parse_args()
     args = train_util.read_config_from_file(args, parser)
 
-    train(args)
+    trainer = NetworkTrainer()
+    trainer.train(args)
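
Note on the refactoring above: the module-level train() function becomes NetworkTrainer.train(), and all model-specific work is routed through overridable hooks (self.get_text_cond, self.call_unet, self.sample_images, self.generate_step_logs, self.vae_scale_factor). The new sdxl_train_network.py added by this patch presumably plugs into that pattern; the following is only a hypothetical sketch, with hook names and argument lists taken from the call sites in the training loop, so the real file may differ in detail.

# Hypothetical sketch, not the actual contents of sdxl_train_network.py.
from train_network import NetworkTrainer, setup_parser
from library import sdxl_model_util, train_util


class SdxlNetworkTrainer(NetworkTrainer):
    def __init__(self):
        super().__init__()
        # SDXL latents are scaled by 0.13025 instead of the SD1/SD2 factor 0.18215
        self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR

    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
        # encode the captions with both SDXL text encoders and return the
        # conditioning tensors that call_unet() below expects
        raise NotImplementedError

    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype):
        # assemble the SDXL size/crop embeddings and run the U-Net
        raise NotImplementedError


if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    trainer = SdxlNetworkTrainer()
    trainer.train(args)

The ss_* metadata assembled in train() ends up in the header of the saved .safetensors file: every value is stringified, and nested structures such as the per-dataset information are packed once as JSON under "ss_datasets". A small inspection snippet (the file name is made up; safe_open() and metadata() are the standard safetensors APIs):

# Hypothetical example of reading back the training metadata written by save_weights().
import json
from safetensors import safe_open

with safe_open("my_lora.safetensors", framework="pt", device="cpu") as f:
    meta = f.metadata() or {}

print(meta.get("ss_base_model_version"))  # which base model family the network was trained on
print(meta.get("ss_network_dim"), meta.get("ss_network_alpha"))

if "ss_datasets" in meta:  # present when a user config with multiple datasets was used
    for ds in json.loads(meta["ss_datasets"]):
        print(ds["resolution"], ds["num_train_images"], ds["num_reg_images"])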