Merge branch 'dev' into dev_device_support

This commit is contained in:
Kohya S
2024-02-12 13:01:54 +09:00
62 changed files with 1387 additions and 993 deletions

View File

@@ -34,6 +34,12 @@ from library.custom_train_functions import (
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
imagenet_templates_small = [
"a photo of a {}",
@@ -97,7 +103,7 @@ def train(args):
train_util.prepare_dataset_args(args, True)
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
print(
logger.warning(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
assert (
@@ -112,7 +118,7 @@ def train(args):
tokenizer = train_util.load_tokenizer(args)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -125,7 +131,7 @@ def train(args):
if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
logger.warning(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
)
else:
@@ -139,7 +145,7 @@ def train(args):
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
logger.info(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
@@ -167,7 +173,7 @@ def train(args):
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
print(f"tokens are added (XTI): {token_ids_XTI}")
logger.info(f"tokens are added (XTI): {token_ids_XTI}")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -176,7 +182,7 @@ def train(args):
if init_token_ids is not None:
for i, token_id in enumerate(token_ids_XTI):
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
@@ -184,22 +190,22 @@ def train(args):
assert len(token_ids) == len(
embeddings
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
# logger.info(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids_XTI, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
logger.info(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.info(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -207,14 +213,14 @@ def train(args):
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
logger.info("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
print("Train with captions.")
logger.info("Train with captions.")
user_config = {
"datasets": [
{
@@ -238,7 +244,7 @@ def train(args):
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print(f"use template for training captions. is object: {args.use_object_template}")
logger.info(f"use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
@@ -262,7 +268,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
return
if len(train_dataset_group) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
if cache_latents:
@@ -293,7 +299,7 @@ def train(args):
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
logger.info("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
@@ -314,7 +320,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
logger.info(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -328,7 +336,7 @@ def train(args):
)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
@@ -366,15 +374,17 @@ def train(args):
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
logger.info("running training / 学習開始")
logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
logger.info(f" num epochs / epoch数: {num_train_epochs}")
logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}")
logger.info(
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
)
logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
@@ -389,17 +399,20 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
# function for saving/removing
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"\nsaving checkpoint: {ckpt_file}")
logger.info("")
logger.info(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -407,12 +420,13 @@ def train(args):
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
logger.info(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs):
print(f"\nepoch {epoch+1}/{num_train_epochs}")
logger.info("")
logger.info(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
text_encoder.train()
@@ -582,7 +596,7 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def save_weights(file, updated_embs, save_dtype):
@@ -643,6 +657,7 @@ def load_weights(file):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
@@ -658,7 +673,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
)
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
@@ -668,7 +685,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
)
parser.add_argument(
"--use_object_template",
action="store_true",