From f54b784d88246a7b3f60a46c3782f29cd892d0c7 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Mon, 10 Jul 2023 22:04:02 +0900
Subject: [PATCH] support textual inversion training

---
 README.md                       |  30 +-
 library/sdxl_train_util.py      |  41 +-
 sdxl_gen_img.py                 |  33 +-
 sdxl_train_textual_inversion.py | 142 +++++
 train_textual_inversion.py      | 987 ++++++++++++++++++--------------
 5 files changed, 787 insertions(+), 446 deletions(-)
 create mode 100644 sdxl_train_textual_inversion.py

diff --git a/README.md b/README.md
index bf1d61c6..4e4150b6 100644
--- a/README.md
+++ b/README.md
@@ -25,18 +25,31 @@ The feature of SDXL training is now available in sdxl branch as an experimental

Summary of the feature:
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
-  - `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage.
+  - `--full_bf16` option is added. Thanks to KohakuBlueleaf!
+    - This option enables full bfloat16 training (including gradients). It is useful for reducing GPU memory usage.
+    - However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
+    - I could not find a bitsandbytes version newer than 0.35.0 that works correctly on Windows.
+    - In addition, full bfloat16 training might be unstable. Please use it at your own risk.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts have the following additional options:
  - `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
  - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. The VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- Image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with the `--no_half_vae` option. It works with `bf16` or without mixed precision.
-- `--weighted_captions` option is not supported yet.
+
+- `--weighted_captions` option is not supported yet for either script.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train the U-Net with a restricted range of timesteps. The default values are 0 and 1000.
+
+- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
+  - `--cache_text_encoder_outputs` is not supported.
+  - `token_string` must currently consist of alphabetic characters only, due to a limitation of the open-clip tokenizer.
+  - There are two options for captions:
+    1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
+    2. Use the `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
+  - See below for the format of the embeddings.
+
+- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
+  - Textual Inversion is supported, but the name used to activate the embeddings in the prompt is reduced to alphabetic characters only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`, as sketched below.
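+
+A minimal sketch of how these trigger words are derived from the embedding file name (this mirrors the logic in `sdxl_gen_img.py`; the helper name and example values here are illustrative only, not part of the scripts):
+
+```python
+import os
+
+def derive_trigger_words(embeds_file, num_vectors_per_token):
+    # the trigger word is the file name with all non-alphabetic characters removed
+    token_string = os.path.splitext(os.path.basename(embeds_file))[0]
+    token_string = "".join(c for c in token_string if c.isalpha())
+    # additional vectors get the suffixes a, b, c, ... (digits would be split by the open-clip tokenizer)
+    return [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]
+
+print(derive_trigger_words("neg_hand_v1.safetensors", 3))  # ['neghandv', 'neghandva', 'neghandvb']
+```
+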
`requirements.txt` is updated to support SDXL training.
@@ -54,7 +67,7 @@ Summary of the feature:
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Values smaller than 32 will not work for SDXL training.

Example of the optimizer settings for Adafactor with a fixed learning rate:
-```
+```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
@@ -62,13 +75,22 @@
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```

+### Format of Textual Inversion embeddings
+
+```python
+from safetensors.torch import save_file
+
+state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
+save_file(state_dict, file)
+```
+
### TODO

-- [ ] Support Textual Inversion training.
- [ ] Support conversion of Diffusers SDXL models.
- [ ] Support `--weighted_captions` option.
- [ ] Change `--output_config` option to continue the training.
- [ ] Extend `--full_bf16` for all the scripts.
+- [x] Support Textual Inversion training.

## About requirements.txt

diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index 675aac3d..0ce09715 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp

 class WrapperTokenizer:
     # open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
+    # make the open clip tokenizer usable in the same way as a HuggingFace tokenizer
     def __init__(self):
         open_clip_tokenizer = open_clip.tokenizer._tokenizer
         self.model_max_length = 77
         self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
         self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
-        self.pad_token_id = 0  # 結果から推定している
+        self.pad_token_id = 0  # 結果から推定している / inferred from observed results

     def __call__(self, *args: Any, **kwds: Any) -> Any:
         return self.tokenize(*args, **kwds)
@@ -107,6 +108,42 @@ class WrapperTokenizer:
             input_ids = input_ids[: eos_index + 1]  # include eos
         return SimpleNamespace(**{"input_ids": input_ids})

+    # for Textual Inversion
+    # わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how does the Web UI handle this?
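+    # the methods below mimic the subset of the HuggingFace tokenizer interface that the
+    # Textual Inversion trainer relies on: encode, add_tokens, convert_tokens_to_ids and __len__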
+
+    def encode(self, text, add_special_tokens=False):
+        assert not add_special_tokens
+        input_ids = open_clip.tokenizer._tokenizer.encode(text)
+        return input_ids
+
+    def add_tokens(self, new_tokens):
+        tokens_to_add = []
+        for token in new_tokens:
+            token = token.lower()
+            if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
+                tokens_to_add.append(token)
+
+        # open clipのtokenizerに直接追加する / add tokens to the open clip tokenizer directly
+        for token in tokens_to_add:
+            open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
+            open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
+            open_clip.tokenizer._tokenizer.vocab_size += 1
+
+            # open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
+            # めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
+            # set the open clip tokenizer's cache directly so the token can be tokenized even though it is not in the BPE merges
+            # this is very rough, so it will break if the open clip tokenizer's internals change
+            open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"
+
+        return len(tokens_to_add)
+
+    def convert_tokens_to_ids(self, tokens):
+        input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
+        return input_ids
+
+    def __len__(self):
+        return open_clip.tokenizer._tokenizer.vocab_size
+

 def load_tokenizers(args: argparse.Namespace):
     print("prepare tokenizers")
@@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
         print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")

     assert (
-        not args.weighted_captions
+        not hasattr(args, "weighted_captions") or not args.weighted_captions
     ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py
index 8f1c17d6..1e20595c 100644
--- a/sdxl_gen_img.py
+++ b/sdxl_gen_img.py
@@ -320,7 +320,7 @@ class PipelineLike:
         self.scheduler = scheduler
         self.safety_checker = None

-        # Textual Inversion # not tested yet
+        # Textual Inversion
         self.token_replacements_list = []
         for _ in range(len(self.text_encoders)):
             self.token_replacements_list.append({})
@@ -341,6 +341,10 @@ class PipelineLike:
         token_replacements = self.token_replacements_list[tokenizer_index]

         def replace_tokens(tokens):
+            # print("replace_tokens", tokens, "=>", token_replacements)
+            if isinstance(tokens, torch.Tensor):
+                tokens = tokens.tolist()
+
             new_tokens = []
             for token in tokens:
                 if token in token_replacements:
@@ -1594,19 +1598,26 @@ def main(args):
             if "string_to_param" in data:
                 data = data["string_to_param"]
-            embeds1 = data["clip_l"]
-            embeds2 = data["clip_g"]
+
+            embeds1 = data["clip_l"]  # text encoder 1
+            embeds2 = data["clip_g"]  # text encoder 2

             num_vectors_per_token = embeds1.size()[0]
             token_string = os.path.splitext(os.path.basename(embeds_file))[0]
-            token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
+
+            # remove non-alphabetic characters so the tokenizer does not split the word
+            # TODO make random alphabet string
+            token_string = "".join([c for c in token_string if c.isalpha()])
+
+            token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]

             # add new word to tokenizer, count is num_vectors_per_token
             num_added_tokens1 = tokenizer1.add_tokens(token_strings)
-            num_added_tokens2 = tokenizer2.add_tokens(token_strings)  # not working now
-            assert (
-                num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token
-            ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
+            num_added_tokens2 = tokenizer2.add_tokens(token_strings)
+            assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
+                f"tokenizer already contains a token for the token string (the file name; non-alphabetic characters are removed): {embeds_file}"
+                + f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
+            )

             token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
             token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
@@ -1617,11 +1628,11 @@ def main(args):
             assert (
                 min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
             ), f"token ids2 is not ordered"
-            assert len(tokenizer1) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer1)}"
-            assert len(tokenizer2) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer2)}"
+            assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
+            assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"

             if num_vectors_per_token > 1:
-                pipe.add_token_replacement(0, token_ids1[0], token_ids1)
+                pipe.add_token_replacement(0, token_ids1[0], token_ids1)  # hoge -> hoge, hogea, hogeb, ...
                 pipe.add_token_replacement(1, token_ids2[0], token_ids2)

             token_ids_embeds1.append((token_ids1, embeds1))
diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py
new file mode 100644
index 00000000..9df37092
--- /dev/null
+++ b/sdxl_train_textual_inversion.py
@@ -0,0 +1,142 @@
+import argparse
+import os
+
+import regex
+import torch
+import open_clip
+from library import sdxl_model_util, sdxl_train_util, train_util
+
+import train_textual_inversion
+
+
+class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
+    def __init__(self):
+        super().__init__()
+        self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
+
+    def assert_extra_args(self, args, train_dataset_group):
+        super().assert_extra_args(args, train_dataset_group)
+        sdxl_train_util.verify_sdxl_training_args(args)
+
+    def load_target_model(self, args, weight_dtype, accelerator):
+        (
+            load_stable_diffusion_format,
+            text_encoder1,
+            text_encoder2,
+            vae,
+            unet,
+            logit_scale,
+            ckpt_info,
+        ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
+
+        self.load_stable_diffusion_format = load_stable_diffusion_format
+        self.logit_scale = logit_scale
+        self.ckpt_info = ckpt_info
+
+        return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
+
+    def load_tokenizer(self, args):
+        tokenizer = sdxl_train_util.load_tokenizers(args)
+        return tokenizer
+
+    def assert_token_string(self, token_string, tokenizers):
+        # tokenizer 1 seems to be ok
+
+        # count the words in the token string: regular expression from open_clip
+        pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
+        words = regex.findall(pat, token_string)
+        word_count = len(words)
+        assert word_count == 1, (
+            f"token string {token_string} contains {word_count} words; please don't use digits, punctuation, or special characters"
+            + f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
+        )
+
+    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
+        # SDXL conditions on both text encoders: encode both sets of input ids and return
+        # the hidden states of encoder 1 and 2 plus the pooled output of encoder 2
+        input_ids1 = batch["input_ids"]
+        input_ids2 = batch["input_ids2"]
+        with torch.enable_grad():
+            input_ids1 = input_ids1.to(accelerator.device)
+            input_ids2 = input_ids2.to(accelerator.device)
+            encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
+                args,
+                input_ids1,
+                input_ids2,
+                tokenizers[0],
+                tokenizers[1],
+                text_encoders[0],
+                text_encoders[1],
+                None if not args.full_fp16 else weight_dtype,
+            )
+        return encoder_hidden_states1, encoder_hidden_states2, pool2
+
+    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
+        noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
+
+        # get size embeddings
+        orig_size = batch["original_sizes_hw"]
+        crop_size = batch["crop_top_lefts"]
+        target_size = batch["target_sizes_hw"]
+        embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)

+        # concat embeddings
+        encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
+        vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
+        text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
+
+        noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
+        return noise_pred
+
+    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+        sdxl_train_util.sample_images(
+            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+        )
+
+    def save_weights(self, file, updated_embs, save_dtype):
+        state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
+
+        if save_dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(save_dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+
+            save_file(state_dict, file)
+        else:
+            torch.save(state_dict, file)
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            data = load_file(file)
+        else:
+            data = torch.load(file, map_location="cpu")
+
+        emb_l = data.get("clip_l", None)  # ViT-L text encoder 1
+        emb_g = data.get("clip_g", None)  # BiG-G text encoder 2
+
+        assert (
+            emb_l is not None or emb_g is not None
+        ), f"weight file does not contain weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
+
+        return [emb_l, emb_g]
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = train_textual_inversion.setup_parser()
+    # sdxl_train_util.add_sdxl_training_arguments(parser) is intentionally not called:
+    # it only adds the text encoder output caching option, which is not supported here
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
+
+    trainer = SdxlTextualInversionTrainer()
+    trainer.train(args)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index ecfaeb4f..09294048 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -8,6 +8,7 @@ from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
+from library import model_util

 import library.train_util as train_util
 import library.huggingface_util as huggingface_util
@@ -20,8 +21,6 @@ import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import (
     apply_snr_weight,
     prepare_scheduler_for_custom_training,
-    pyramid_noise_like,
-    apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
 )

@@ -78,503 +77,632 @@ imagenet_style_templates_small = [
 ]

-def train(args):
-    if args.output_name is None:
-        args.output_name = args.token_string
-    use_template = args.use_object_template or args.use_style_template
+class TextualInversionTrainer:
+    # Textual Inversion trainer for SD1/SD2; the model-specific parts are overridden for SDXL in sdxl_train_textual_inversion.py
+    def __init__(self):
+        self.vae_scale_factor = 0.18215

-    train_util.verify_training_args(args)
-    train_util.prepare_dataset_args(args, True)
+    def assert_extra_args(self, args, train_dataset_group):
+        pass

-    cache_latents = args.cache_latents
+    def load_target_model(self, args, weight_dtype, accelerator):
+        text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
+        return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet

-    if args.seed is not None:
-        set_seed(args.seed)
+    def load_tokenizer(self, args):
+        tokenizer = train_util.load_tokenizer(args)
+        return tokenizer

-    tokenizer = train_util.load_tokenizer(args)
+    def assert_token_string(self, token_string, tokenizers):
+        pass
+
+    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
+        with torch.enable_grad():
+            input_ids = batch["input_ids"].to(accelerator.device)
+            encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
+        return encoder_hidden_states

-    # acceleratorを準備する
-    print("prepare accelerator")
-    accelerator = train_util.prepare_accelerator(args)
+    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
+        noise_pred = unet(noisy_latents, timesteps, text_conds).sample
+        return noise_pred

-    # mixed precisionに対応した型を用意しておき適宜castする
-    weight_dtype, save_dtype = train_util.prepare_dtype(args)
-
-    # モデルを読み込む
-    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
-
-    # Convert the init_word to token_id
-    if args.init_word is not None:
-        init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
-        if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
-            accelerator.print(
-                f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
-            )
-    else:
-        init_token_ids = None
-
-    # add new word to tokenizer, count is num_vectors_per_token
-    token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
-    num_added_tokens = tokenizer.add_tokens(token_strings)
-    assert (
-        num_added_tokens == args.num_vectors_per_token
-    ), f"tokenizer has same word to token string. 
please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - accelerator.print(f"tokens are added: {token_ids}") - assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" - assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - - # Resize the token embeddings as we are adding new special tokens to the tokenizer - text_encoder.resize_token_embeddings(len(tokenizer)) - - # Initialise the newly added placeholder token with the embeddings of the initializer token - token_embeds = text_encoder.get_input_embeddings().weight.data - if init_token_ids is not None: - for i, token_id in enumerate(token_ids): - token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] - # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - - # load weights - if args.weights is not None: - embeddings = load_weights(args.weights) - assert len(token_ids) == len( - embeddings - ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # accelerator.print(token_ids, embeddings.size()) - for token_id, embedding in zip(token_ids, embeddings): - token_embeds[token_id] = embedding - # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - accelerator.print(f"weighs loaded") - - accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") - - # データセットを準備する - if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) - if args.dataset_config is not None: - accelerator.print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - accelerator.print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) - ) - ) - else: - use_dreambooth_method = args.in_json is None - if use_dreambooth_method: - accelerator.print("Use DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - print("Train with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - else: - train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) - - current_epoch = Value("i", 0) - current_step = Value("i", 0) - ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) - - # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 - if use_template: - accelerator.print("use template for training captions. 
is object: {args.use_object_template}") - templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small - replace_to = " ".join(token_strings) - captions = [] - for tmpl in templates: - captions.append(tmpl.format(replace_to)) - train_dataset_group.add_replacement("", captions) - - if args.num_vectors_per_token > 1: - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - else: - if args.num_vectors_per_token > 1: - replace_to = " ".join(token_strings) - train_dataset_group.add_replacement(args.token_string, replace_to) - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group, show_input_ids=True) - return - if len(train_dataset_group) == 0: - accelerator.print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") - return - - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - accelerator.wait_for_everyone() - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - # 学習に必要なクラスを準備する - accelerator.print("prepare optimizer, data loader etc.") - trainable_params = text_encoder.get_input_embeddings().parameters() - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, - batch_size=1, - shuffle=True, - collate_fn=collater, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement): + train_util.sample_images( + accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement ) - accelerator.print(f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - # データセット側にも学習ステップを送信 - train_dataset_group.set_max_train_steps(args.max_train_steps) + def save_weights(self, file, updated_embs, save_dtype): + state_dict = {"emb_params": updated_embs[0]} - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v - # acceleratorがなんかよろしくやってくれるらしい - text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, lr_scheduler - ) + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file - # transform DDP after prepare - text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet) + save_file(state_dict, file) + else: + torch.save(state_dict, file) # can be loaded in Web UI - index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - # accelerator.print(len(index_no_updates), torch.sum(index_no_updates)) - orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file - # Freeze all parameters except for the token embeddings in text encoder - text_encoder.requires_grad_(True) - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) + data = load_file(file) + else: + # compatible to Web UI's file format + data = torch.load(file, map_location="cpu") + if type(data) != dict: + raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - else: - unet.eval() + if "string_to_param" in data: # textual inversion embeddings + data = data["string_to_param"] + if hasattr(data, "_parameters"): # support old PyTorch? 
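+                    # assumption: old embeddings appear to store the tensors in an nn.ParameterDict-like
+                    # object; unwrap its _parameters so the values can be read like a plain dict below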
+                    data = getattr(data, "_parameters")

-    if not cache_latents:
-        vae.requires_grad_(False)
-        vae.eval()
-        vae.to(accelerator.device, dtype=weight_dtype)
+            emb = next(iter(data.values()))
+            if type(emb) != torch.Tensor:
+                raise ValueError(f"weight file does not contain a Tensor / 重みファイルのデータがTensorではありません: {file}")

-    # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
-    if args.full_fp16:
-        train_util.patch_accelerator_for_fp16_training(accelerator)
-        text_encoder.to(weight_dtype)
+            if len(emb.size()) == 1:
+                emb = emb.unsqueeze(0)

-    # resumeする
-    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
+        return [emb]

-    # epoch数を計算する
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
-    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
-        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
+    def train(self, args):
+        if args.output_name is None:
+            args.output_name = args.token_string
+        use_template = args.use_object_template or args.use_style_template

-    # 学習する
-    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    accelerator.print("running training / 学習開始")
-    accelerator.print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-    accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-    accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
-    accelerator.print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
-    accelerator.print(
-        f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
-    )
-    accelerator.print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+        train_util.verify_training_args(args)
+        train_util.prepare_dataset_args(args, True)

-    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
-    global_step = 0
+        cache_latents = args.cache_latents

-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
-    )
-    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
+        if args.seed is not None:
+            set_seed(args.seed)

-    if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
+        tokenizer_or_list = self.load_tokenizer(args)  # a single tokenizer or a list of tokenizers
+        tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]

-    # function for saving/removing
-    def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
-        os.makedirs(args.output_dir, exist_ok=True)
-        ckpt_file = os.path.join(args.output_dir, ckpt_name)
+        # acceleratorを準備する / prepare the accelerator
+        print("prepare accelerator")
+        accelerator = train_util.prepare_accelerator(args)

-        accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
-        save_weights(ckpt_file, embs, save_dtype)
-        if args.huggingface_repo_id is not None:
-            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
+        # mixed precisionに対応した型を用意しておき適宜castする / prepare dtypes for mixed precision and cast as needed
+        weight_dtype, save_dtype = train_util.prepare_dtype(args)

-    def remove_model(old_ckpt_name):
-        old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
-        if os.path.exists(old_ckpt_file):
-            accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
-            os.remove(old_ckpt_file)
+        # モデルを読み込む / load the models
+        model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
+        text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list

-    # training loop
-    for epoch in range(num_train_epochs):
-        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
-        current_epoch.value = epoch + 1
+        if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
+            accelerator.print(
+                "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
+                + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです"
+            )

-        text_encoder.train()
+        # Convert the init_word to token_id
+        init_token_ids_list = []
+        if args.init_word is not None:
+            for i, tokenizer in enumerate(tokenizers):
+                init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
+                if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
+                    accelerator.print(
+                        f"token length for init words is not the same as num_vectors_per_token; init words will be repeated or truncated / "
+                        + f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}"
+                    )
+                init_token_ids_list.append(init_token_ids)
+        else:
+            init_token_ids_list = [None] * len(tokenizers)

-        loss_total = 0
+        # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
+        # add new word to tokenizer, count is num_vectors_per_token

-        for step, batch in enumerate(train_dataloader):
-            current_step.value = global_step
-            with accelerator.accumulate(text_encoder):
-                with torch.no_grad():
-                    if "latents" in batch and batch["latents"] is not None:
-                        latents = batch["latents"].to(accelerator.device)
-                    else:
-                        # latentに変換
-                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
-                        latents = latents * 0.18215
-                b_size = latents.shape[0]
+        # token_stringが hoge の場合、"hoge", "hogea", "hogeb", ... が追加される
+        # 当初は "hoge", "hoge1", "hoge2", ... としていたが、open clipのtokenizerは数字を含む単語を分割してしまうため(;^ω^)、a, b, ... とした

-                # Get the text embedding for conditioning
-                input_ids = batch["input_ids"].to(accelerator.device)
-                # use float instead of fp16/bf16 because text encoder is float
-                encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
+        # if token_string is hoge, "hoge", "hogea", "hogeb", ... are added
+        # originally, "hoge", "hoge1", "hoge2", ... were used, but open clip's tokenizer splits words including numbers (;^ω^), so a, b, ... are used

-                # Sample noise, sample a random timestep for each image, and add noise to the latents,
-                # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+        self.assert_token_string(args.token_string, tokenizers)

-                # Predict the noise residual
-                with accelerator.autocast():
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+        token_strings = [args.token_string] + [
+            f"{args.token_string}{chr(ord('a') + i)}" for i in range(args.num_vectors_per_token - 1)
+        ]
+        token_ids_list = []
+        token_embeds_list = []
+        for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
+            num_added_tokens = tokenizer.add_tokens(token_strings)
+            assert (
+                num_added_tokens == args.num_vectors_per_token
+            ), f"tokenizer already contains the token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: tokenizer {i+1}, {args.token_string}"

-                if args.v_parameterization:
-                    # v-parameterization training
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+            token_ids = tokenizer.convert_tokens_to_ids(token_strings)
+            accelerator.print(f"tokens are added for tokenizer {i+1}: {token_ids}")
+            assert (
+                min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
+            ), f"token ids is not ordered: tokenizer {i+1}, {token_ids}"
+            assert (
+                len(tokenizer) - 1 == token_ids[-1]
+            ), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}"
+            token_ids_list.append(token_ids)
+
+            # Resize the token embeddings as we are adding new special tokens to the tokenizer
+            text_encoder.resize_token_embeddings(len(tokenizer))
+
+            # Initialise the newly added placeholder token with the embeddings of the initializer token
+            token_embeds = text_encoder.get_input_embeddings().weight.data
+            if init_token_ids is not None:
+                for i, token_id in enumerate(token_ids):
+                    token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
+                    # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+            token_embeds_list.append(token_embeds)
+
+        # load weights
+        if args.weights is not None:
+            embeddings_list = self.load_weights(args.weights)
+            assert len(token_ids) == len(
+                embeddings_list[0]
+            ), f"num_vectors_per_token does not match the loaded weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings_list[0])}"
+            # accelerator.print(token_ids, embeddings.size())
+            for token_ids, embeddings, token_embeds in zip(token_ids_list, embeddings_list, token_embeds_list):
+                for token_id, embedding in zip(token_ids, embeddings):
+                    token_embeds[token_id] = embedding
+                    # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+            accelerator.print(f"weights loaded")
+
+        accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
+
+        # データセットを準備する / prepare the dataset
+        if args.dataset_class is None:
+            blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
+            if args.dataset_config is not None:
+                accelerator.print(f"Load dataset config from {args.dataset_config}")
+                user_config = config_util.load_user_config(args.dataset_config)
+                ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+                if any(getattr(args, attr) is not None for attr in ignored):
+                    accelerator.print(
+                        "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                            ", ".join(ignored)
+                        )
+                    )
+            else:
+                use_dreambooth_method = args.in_json is None
+                if use_dreambooth_method:
+                    accelerator.print("Use DreamBooth method.")
+                    user_config = {
+                        "datasets": [
+                            {
+                                "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                    args.train_data_dir, args.reg_data_dir
+                                )
+                            }
+                        ]
+                    }
-                else:
-                    target = noise
+                else:
+                    print("Train with captions.")
+                    user_config = {
+                        "datasets": [
+                            {
+                                "subsets": [
+                                    {
+                                        "image_dir": args.train_data_dir,
+                                        "metadata_file": args.in_json,
+                                    }
+                                ]
+                            }
+                        ]
+                    }

-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
-                loss = loss.mean([1, 2, 3])
+            blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
+            train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+        else:
+            train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)

-                loss_weights = batch["loss_weights"]  # 各sampleごとのweight
-                loss = loss * loss_weights
+        self.assert_extra_args(args, train_dataset_group)

-                if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
-                if args.scale_v_pred_loss_like_noise_pred:
-                    loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+        current_epoch = Value("i", 0)
+        current_step = Value("i", 0)
+        ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+        collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

-                loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
+        # make captions: "tokenstring tokenstringa tokenstringb ..." という文字列に書き換える超乱暴な実装 / very rough implementation: rewrite captions to "tokenstring tokenstringa tokenstringb ..."
+        if use_template:
+            accelerator.print(f"use template for training captions. is object: {args.use_object_template}")
+            templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
+            replace_to = " ".join(token_strings)
+            captions = []
+            for tmpl in templates:
+                captions.append(tmpl.format(replace_to))
+            train_dataset_group.add_replacement("", captions)

-                accelerator.backward(loss)
-                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                    params_to_clip = text_encoder.get_input_embeddings().parameters()
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+            # サンプル生成用 / for sample generation
+            if args.num_vectors_per_token > 1:
+                prompt_replacement = (args.token_string, replace_to)
+            else:
+                prompt_replacement = None
+        else:
+            # サンプル生成用 / for sample generation
+            if args.num_vectors_per_token > 1:
+                replace_to = " ".join(token_strings)
+                train_dataset_group.add_replacement(args.token_string, replace_to)
+                prompt_replacement = (args.token_string, replace_to)
+            else:
+                prompt_replacement = None

-                optimizer.step()
-                lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+        if args.debug_dataset:
+            train_util.debug_dataset(train_dataset_group, show_input_ids=True)
+            return
+        if len(train_dataset_group) == 0:
+            accelerator.print("No data found. 
Please verify arguments / 画像がありません。引数指定を確認してください") + return - # Let's make sure we don't update any embedding weights besides the newly added token - with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ - index_no_updates - ] + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + vae.set_use_memory_efficient_attention_xformers(args.xformers) - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - updated_embs = ( - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + accelerator.wait_for_everyone() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + for text_encoder in text_encoders: + text_encoder.gradient_checkpointing_enable() + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + trainable_params = [] + for text_encoder in text_encoders: + trainable_params += text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. 
steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # acceleratorがなんかよろしくやってくれるらしい + if len(text_encoders) == 1: + text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder_or_list, optimizer, train_dataloader, lr_scheduler + ) + # transform DDP after prepare + text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet) + + elif len(text_encoders) == 2: + text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler + ) + # transform DDP after prepare + text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet) + + text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] + + else: + raise NotImplementedError() + + index_no_updates_list = [] + orig_embeds_params_list = [] + for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders): + index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] + index_no_updates_list.append(index_no_updates) + + # accelerator.print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + orig_embeds_params_list.append(orig_embeds_params) + + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す + unet.train() + else: + unet.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + for text_encoder in text_encoders: + text_encoder.to(weight_dtype) + if args.full_bf16: + for text_encoder in text_encoders: + text_encoder.to(weight_dtype) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 
1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) + + # function for saving/removing + def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + self.save_weights(ckpt_file, embs_list, save_dtype) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for text_encoder in text_encoders: + text_encoder.train() + + loss_total = 0 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(text_encoders[0]): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * self.vae_scale_factor + + # Get the text embedding for conditioning + text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype ) - ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, global_step) - save_model(ckpt_name, updated_embs, global_step, epoch) + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + for text_encoder, orig_embeds_params, index_no_updates in zip( + text_encoders, orig_embeds_params_list, index_no_updates_list + ): + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds_params[index_no_updates] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + self.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, ) - accelerator.log(logs, step=global_step) - loss_total += current_loss - avr_loss = loss_total / (step + 1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + updated_embs_list = [] + for text_encoder, token_ids in zip(text_encoders, token_ids_list): + updated_embs = ( + accelerator.unwrap_model(text_encoder) + .get_input_embeddings() + .weight[token_ids] + .data.detach() + .clone() + ) + updated_embs_list.append(updated_embs) - if global_step >= args.max_train_steps: - break + ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, global_step) + save_model(ckpt_name, updated_embs_list, global_step, epoch) - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch + 1) + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - accelerator.wait_for_everyone() + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) - updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if accelerator.is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, updated_embs, epoch + 1, global_step) + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) + if global_step >= args.max_train_steps: + break - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) - train_util.sample_images( - accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) + accelerator.wait_for_everyone() - # end of epoch + updated_embs_list = [] + for text_encoder, token_ids in zip(text_encoders, token_ids_list): + updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + updated_embs_list.append(updated_embs) - is_main_process = accelerator.is_main_process - if is_main_process: - text_encoder = accelerator.unwrap_model(text_encoder) + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if accelerator.is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, updated_embs_list, epoch + 1, global_step) - accelerator.end_training() + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." 
+ args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) - if args.save_state and is_main_process: - train_util.save_state_on_train_end(args, accelerator) + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() + self.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, + ) - del accelerator # この後メモリを使うのでこれは消す + # end of epoch - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True) + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = accelerator.unwrap_model(text_encoder) - print("model saved.") + accelerator.end_training() + if args.save_state and is_main_process: + train_util.save_state_on_train_end(args, accelerator) -def save_weights(file, updated_embs, save_dtype): - state_dict = {"emb_params": updated_embs} + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True) - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, file) - else: - torch.save(state_dict, file) # can be loaded in Web UI - - -def load_weights(file): - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - data = load_file(file) - else: - # compatible to Web UI's file format - data = torch.load(file, map_location="cpu") - if type(data) != dict: - raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") - - if "string_to_param" in data: # textual inversion embeddings - data = data["string_to_param"] - if hasattr(data, "_parameters"): # support old PyTorch? - data = getattr(data, "_parameters") - - emb = next(iter(data.values())) - if type(emb) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") - - if len(emb.size()) == 1: - emb = emb.unsqueeze(0) - - return emb + print("model saved.") def setup_parser() -> argparse.ArgumentParser: @@ -626,4 +754,5 @@ if __name__ == "__main__": args = parser.parse_args() args = train_util.read_config_from_file(args, parser) - train(args) + trainer = TextualInversionTrainer() + trainer.train(args)