mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
update diffusers to 1.16 | train_textual_inversion
This commit is contained in:
@@ -1,15 +1,12 @@
|
|||||||
import importlib
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
@@ -104,7 +101,7 @@ def train(args):
|
|||||||
if args.init_word is not None:
|
if args.init_word is not None:
|
||||||
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||||
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
||||||
print(
|
accelerator.print(
|
||||||
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -118,7 +115,7 @@ def train(args):
|
|||||||
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
||||||
|
|
||||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
print(f"tokens are added: {token_ids}")
|
accelerator.print(f"tokens are added: {token_ids}")
|
||||||
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||||
|
|
||||||
@@ -130,7 +127,7 @@ def train(args):
|
|||||||
if init_token_ids is not None:
|
if init_token_ids is not None:
|
||||||
for i, token_id in enumerate(token_ids):
|
for i, token_id in enumerate(token_ids):
|
||||||
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
|
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
|
||||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
# accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||||
|
|
||||||
# load weights
|
# load weights
|
||||||
if args.weights is not None:
|
if args.weights is not None:
|
||||||
@@ -138,22 +135,22 @@ def train(args):
|
|||||||
assert len(token_ids) == len(
|
assert len(token_ids) == len(
|
||||||
embeddings
|
embeddings
|
||||||
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
||||||
# print(token_ids, embeddings.size())
|
# accelerator.print(token_ids, embeddings.size())
|
||||||
for token_id, embedding in zip(token_ids, embeddings):
|
for token_id, embedding in zip(token_ids, embeddings):
|
||||||
token_embeds[token_id] = embedding
|
token_embeds[token_id] = embedding
|
||||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
# accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||||
print(f"weighs loaded")
|
accelerator.print(f"weighs loaded")
|
||||||
|
|
||||||
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
accelerator.print(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
accelerator.print(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
@@ -161,14 +158,14 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
use_dreambooth_method = args.in_json is None
|
use_dreambooth_method = args.in_json is None
|
||||||
if use_dreambooth_method:
|
if use_dreambooth_method:
|
||||||
print("Use DreamBooth method.")
|
accelerator.print("Use DreamBooth method.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
print("Train with captions.")
|
accelerator.print("Train with captions.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -192,7 +189,7 @@ def train(args):
|
|||||||
|
|
||||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||||
if use_template:
|
if use_template:
|
||||||
print("use template for training captions. is object: {args.use_object_template}")
|
accelerator.print("use template for training captions. is object: {args.use_object_template}")
|
||||||
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
||||||
replace_to = " ".join(token_strings)
|
replace_to = " ".join(token_strings)
|
||||||
captions = []
|
captions = []
|
||||||
@@ -216,7 +213,7 @@ def train(args):
|
|||||||
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
||||||
return
|
return
|
||||||
if len(train_dataset_group) == 0:
|
if len(train_dataset_group) == 0:
|
||||||
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
accelerator.print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
||||||
return
|
return
|
||||||
|
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
@@ -246,7 +243,7 @@ def train(args):
|
|||||||
text_encoder.gradient_checkpointing_enable()
|
text_encoder.gradient_checkpointing_enable()
|
||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
print("prepare optimizer, data loader etc.")
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
trainable_params = text_encoder.get_input_embeddings().parameters()
|
trainable_params = text_encoder.get_input_embeddings().parameters()
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
@@ -267,7 +264,7 @@ def train(args):
|
|||||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
)
|
)
|
||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# データセット側にも学習ステップを送信
|
# データセット側にも学習ステップを送信
|
||||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||||
@@ -284,7 +281,7 @@ def train(args):
|
|||||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||||
|
|
||||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
# accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
|
|
||||||
# Freeze all parameters except for the token embeddings in text encoder
|
# Freeze all parameters except for the token embeddings in text encoder
|
||||||
@@ -322,15 +319,15 @@ def train(args):
|
|||||||
|
|
||||||
# 学習する
|
# 学習する
|
||||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
print("running training / 学習開始")
|
accelerator.print("running training / 学習開始")
|
||||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
global_step = 0
|
global_step = 0
|
||||||
@@ -347,7 +344,7 @@ def train(args):
|
|||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
||||||
save_weights(ckpt_file, embs, save_dtype)
|
save_weights(ckpt_file, embs, save_dtype)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||||
@@ -355,12 +352,12 @@ def train(args):
|
|||||||
def remove_model(old_ckpt_name):
|
def remove_model(old_ckpt_name):
|
||||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||||
if os.path.exists(old_ckpt_file):
|
if os.path.exists(old_ckpt_file):
|
||||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||||
os.remove(old_ckpt_file)
|
os.remove(old_ckpt_file)
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user