update diffusers to 0.16 | train_textual_inversion

Author: ddPn08
Date: 2023-06-01 10:37:23 +09:00
Parent: 1f1cae6c5a
Commit: 23c4e5cb01

train_textual_inversion.py

@@ -1,15 +1,12 @@
-import importlib
 import argparse
 import gc
 import math
 import os
-import toml
 from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
@@ -104,7 +101,7 @@ def train(args):
     if args.init_word is not None:
         init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
         if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
-            print(
+            accelerator.print(
                 f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
             )
     else:
@@ -118,7 +115,7 @@ def train(args):
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings) token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}") accelerator.print(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
@@ -130,7 +127,7 @@ def train(args):
     if init_token_ids is not None:
         for i, token_id in enumerate(token_ids):
             token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
-            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+            # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
 
     # load weights
     if args.weights is not None:
@@ -138,22 +135,22 @@ def train(args):
         assert len(token_ids) == len(
             embeddings
         ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
-        # print(token_ids, embeddings.size())
+        # accelerator.print(token_ids, embeddings.size())
         for token_id, embedding in zip(token_ids, embeddings):
             token_embeds[token_id] = embedding
-            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
-        print(f"weighs loaded")
+            # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+        accelerator.print(f"weighs loaded")
 
-    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
+    accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
 
     # データセットを準備する
     blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
     if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
+        accelerator.print(f"Load dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)
         ignored = ["train_data_dir", "reg_data_dir", "in_json"]
         if any(getattr(args, attr) is not None for attr in ignored):
-            print(
+            accelerator.print(
                 "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                     ", ".join(ignored)
                 )
@@ -161,14 +158,14 @@ def train(args):
     else:
         use_dreambooth_method = args.in_json is None
         if use_dreambooth_method:
-            print("Use DreamBooth method.")
+            accelerator.print("Use DreamBooth method.")
             user_config = {
                 "datasets": [
                     {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
                 ]
             }
         else:
-            print("Train with captions.")
+            accelerator.print("Train with captions.")
             user_config = {
                 "datasets": [
                     {
@@ -192,7 +189,7 @@ def train(args):
 
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
     if use_template:
-        print("use template for training captions. is object: {args.use_object_template}")
+        accelerator.print("use template for training captions. is object: {args.use_object_template}")
         templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
         replace_to = " ".join(token_strings)
         captions = []
@@ -216,7 +213,7 @@ def train(args):
         train_util.debug_dataset(train_dataset_group, show_input_ids=True)
         return
     if len(train_dataset_group) == 0:
-        print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
+        accelerator.print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
         return
 
     if cache_latents:
@@ -246,7 +243,7 @@ def train(args):
         text_encoder.gradient_checkpointing_enable()
 
     # 学習に必要なクラスを準備する
-    print("prepare optimizer, data loader etc.")
+    accelerator.print("prepare optimizer, data loader etc.")
     trainable_params = text_encoder.get_input_embeddings().parameters()
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
@@ -267,7 +264,7 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -284,7 +281,7 @@ def train(args):
     text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
 
     index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
-    # print(len(index_no_updates), torch.sum(index_no_updates))
+    # accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
 
     # Freeze all parameters except for the token embeddings in text encoder
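
Note on the context around this hunk: it sets up the standard textual-inversion freeze, as in the Hugging Face example this script follows. Only the newly appended token rows of the embedding matrix should train, so the code snapshots the original weights (orig_embeds_params) and builds a boolean mask (index_no_updates) selecting every pre-existing token id; the training loop (not shown in this diff) uses the mask to copy the frozen rows back after each optimizer step. A minimal standalone sketch of that restore pattern, with made-up sizes:

import torch

vocab_size, dim = 12, 4
token_embeds = torch.randn(vocab_size, dim)    # stand-in for the embedding weight
orig_embeds_params = token_embeds.detach().clone()
first_new_id = 10                              # token_ids[0]: new tokens sit at the vocab tail
index_no_updates = torch.arange(vocab_size) < first_new_id

# ... an optimizer step would update all of token_embeds here ...

with torch.no_grad():
    # restore the frozen rows so only the newly added embeddings ever change
    token_embeds[index_no_updates] = orig_embeds_params[index_no_updates]
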
@@ -322,15 +319,15 @@ def train(args):
     # 学習する
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    print("running training / 学習開始")
-    print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-    print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-    print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    print(f" num epochs / epoch数: {num_train_epochs}")
-    print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
-    print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+    accelerator.print("running training / 学習開始")
+    accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
+    accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
+    accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
 
@@ -347,7 +344,7 @@ def train(args):
         os.makedirs(args.output_dir, exist_ok=True)
         ckpt_file = os.path.join(args.output_dir, ckpt_name)
 
-        print(f"\nsaving checkpoint: {ckpt_file}")
+        accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
         save_weights(ckpt_file, embs, save_dtype)
         if args.huggingface_repo_id is not None:
             huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -355,12 +352,12 @@ def train(args):
     def remove_model(old_ckpt_name):
         old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
         if os.path.exists(old_ckpt_file):
-            print(f"removing old checkpoint: {old_ckpt_file}")
+            accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
             os.remove(old_ckpt_file)
 
     # training loop
     for epoch in range(num_train_epochs):
-        print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
         text_encoder.train()
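
Every hunk in this commit applies the same mechanical change: print becomes accelerator.print. Accelerate's Accelerator.print helper writes only on the local main process, so a multi-GPU run launched with accelerate launch logs each message once instead of once per process. A minimal standalone sketch of the difference (demo.py is a stand-in name, not part of this repo):

from accelerate import Accelerator

accelerator = Accelerator()

# plain print fires in every process: with --num_processes 4 it appears 4 times
print(f"hello from process {accelerator.process_index}")

# accelerator.print fires only on the local main process, so it appears once
accelerator.print("hello once")

Launched as accelerate launch --num_processes 4 demo.py, the first line prints four times and the second once; in a single-process run Accelerator.print behaves exactly like print, so the script's logging is unchanged for ordinary use.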