From 23c4e5cb016a37840919573774838809985ff25a Mon Sep 17 00:00:00 2001
From: ddPn08
Date: Thu, 1 Jun 2023 10:37:23 +0900
Subject: [PATCH] update diffusers to 1.16 | train_textual_inversion

---
 train_textual_inversion.py | 59 ++++++++++++++++++--------------------
 1 file changed, 28 insertions(+), 31 deletions(-)

diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index b73027de..3d028442 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -1,15 +1,12 @@
-import importlib
 import argparse
 import gc
 import math
 import os
-import toml
 from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
@@ -104,7 +101,7 @@ def train(args):
     if args.init_word is not None:
         init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
         if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
-            print(
+            accelerator.print(
                 f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
             )
     else:
@@ -118,7 +115,7 @@ def train(args):
     ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
 
     token_ids = tokenizer.convert_tokens_to_ids(token_strings)
-    print(f"tokens are added: {token_ids}")
+    accelerator.print(f"tokens are added: {token_ids}")
     assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
     assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
 
@@ -130,7 +127,7 @@ def train(args):
     if init_token_ids is not None:
         for i, token_id in enumerate(token_ids):
             token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
-            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+            # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
 
     # load weights
     if args.weights is not None:
@@ -138,22 +135,22 @@ def train(args):
         assert len(token_ids) == len(
             embeddings
         ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
-        # print(token_ids, embeddings.size())
+        # accelerator.print(token_ids, embeddings.size())
         for token_id, embedding in zip(token_ids, embeddings):
             token_embeds[token_id] = embedding
-            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
-        print(f"weighs loaded")
+            # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+        accelerator.print(f"weighs loaded")
 
-    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
+    accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
 
     # データセットを準備する
     blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
     if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
+        accelerator.print(f"Load dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)
         ignored = ["train_data_dir", "reg_data_dir", "in_json"]
         if any(getattr(args, attr) is not None for attr in ignored):
-            print(
+            accelerator.print(
                 "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                     ", ".join(ignored)
                 )
             )
@@ -161,14 +158,14 @@ def train(args):
     else:
         use_dreambooth_method = args.in_json is None
         if use_dreambooth_method:
-            print("Use DreamBooth method.")
+            accelerator.print("Use DreamBooth method.")
             user_config = {
                 "datasets": [
                     {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
                 ]
             }
         else:
-            print("Train with captions.")
+            accelerator.print("Train with captions.")
            user_config = {
                 "datasets": [
                     {
@@ -192,7 +189,7 @@ def train(args):
 
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
     if use_template:
-        print("use template for training captions. is object: {args.use_object_template}")
+        accelerator.print("use template for training captions. is object: {args.use_object_template}")
         templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
         replace_to = " ".join(token_strings)
         captions = []
@@ -216,7 +213,7 @@ def train(args):
         train_util.debug_dataset(train_dataset_group, show_input_ids=True)
         return
     if len(train_dataset_group) == 0:
-        print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
+        accelerator.print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
         return
 
     if cache_latents:
@@ -246,7 +243,7 @@ def train(args):
         text_encoder.gradient_checkpointing_enable()
 
     # 学習に必要なクラスを準備する
-    print("prepare optimizer, data loader etc.")
+    accelerator.print("prepare optimizer, data loader etc.")
     trainable_params = text_encoder.get_input_embeddings().parameters()
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
@@ -267,7 +264,7 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -284,7 +281,7 @@ def train(args):
     text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
 
     index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
-    # print(len(index_no_updates), torch.sum(index_no_updates))
+    # accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
 
     # Freeze all parameters except for the token embeddings in text encoder
@@ -322,15 +319,15 @@ def train(args):
 
     # 学習する
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    print("running training / 学習開始")
-    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    print(f"  num epochs / epoch数: {num_train_epochs}")
-    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
-    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+    accelerator.print("running training / 学習開始")
+    accelerator.print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
+    accelerator.print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
+    accelerator.print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    accelerator.print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
@@ -347,7 +344,7 @@ def train(args):
         os.makedirs(args.output_dir, exist_ok=True)
         ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
-        print(f"\nsaving checkpoint: {ckpt_file}")
+        accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
         save_weights(ckpt_file, embs, save_dtype)
         if args.huggingface_repo_id is not None:
             huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -355,12 +352,12 @@
     def remove_model(old_ckpt_name):
         old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
         if os.path.exists(old_ckpt_file):
-            print(f"removing old checkpoint: {old_ckpt_file}")
+            accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
             os.remove(old_ckpt_file)
 
     # training loop
     for epoch in range(num_train_epochs):
-        print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
         text_encoder.train()