Merge branch 'dev' into dev

Authored by Kohya S on 2023-03-19 10:25:22 +09:00, committed by GitHub
14 changed files with 4272 additions and 3551 deletions


@@ -127,7 +127,33 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History

- 11 Mar. 2023, 2023/3/11:
  - Fix an issue where `svd_merge_lora.py` causes an error about the device.
- 10 Mar. 2023, 2023/3/10: release v0.5.1
  - Fix so that the LoRA modules in the model are the same as in previous versions (before 0.5.0) when Conv2d-3x3 is disabled (no `conv_dim` arg, the default).
    - Conv2d layers with kernel size 1x1 in ResNet modules were accidentally included as targets in v0.5.0.
    - Models trained with v0.5.0 should still work with the Web UI's built-in LoRA feature and the Additional Networks extension.
  - Fix an issue where the dim (rank) of a LoRA module was limited to the in/out dimensions of the target Linear/Conv2d (when a dim greater than 320 was specified).
  - `resize_lora.py` now has a `dynamic resizing` feature, meaning each LoRA module can end up with a different rank (dim). Thanks to mgz-dev for this great work! A rough sketch of the idea follows this list.
    - The appropriate rank is selected based on the complexity of each module, using the algorithm specified by command line arguments. For details: https://github.com/kohya-ss/sd-scripts/pull/243
  - Multi-GPU training is finally supported in `train_network.py`. Thanks to ddPn08 for solving this long-standing issue!
  - Datasets using the fine-tuning method (with a metadata json file) now work without images if the `.npz` files exist. Thanks to rvhfxb!
  - `train_network.py` now works even when the current directory is not the directory the script is located in. Thanks to mio2333!
  - Fix an issue where `extract_lora_from_models.py` and `svd_merge_lora.py` did not work with ranks higher than 320.
- 9 Mar. 2023, 2023/3/9: release v0.5.0
  - There may be problems due to major changes. If you cannot revert back to the previous version when problems occur, please do not update for a while.
  - Minimum metadata (module name, dim, alpha and network_args) is recorded even with `--no_metadata`, issue https://github.com/kohya-ss/sd-scripts/issues/254
  - `train_network.py` supports LoRA for Conv2d-3x3 (extended to conv2d with a kernel size not 1x1).
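The dynamic resizing mentioned above works, roughly, by multiplying each LoRA pair (up weight, down weight) back into a full delta matrix, taking an SVD, and choosing a per-module rank from the singular values instead of one fixed rank for every module. The sketch below only illustrates that idea and is not the code in `resize_lora.py`; the helper name and the `sv_ratio` threshold are made up for this example.

import torch

def choose_dynamic_rank(up_weight, down_weight, max_rank, sv_ratio=0.9):
    # Hypothetical helper: recombine the LoRA factors and keep enough singular
    # values to cover sv_ratio of their total sum, capped at max_rank.
    mat = up_weight.flatten(1).float() @ down_weight.flatten(1).float()
    _, S, _ = torch.linalg.svd(mat)
    cumulative = torch.cumsum(S, dim=0) / S.sum()
    new_rank = int((cumulative < sv_ratio).sum().item()) + 1
    return min(new_rank, max_rank)

The actual script selects the metric and threshold through command line arguments, as described in the pull request linked above.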


@@ -5,6 +5,7 @@ import argparse
import gc
import math
import os
+import toml
from tqdm import tqdm
import torch
@@ -15,349 +16,391 @@ from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)


def collate_fn(examples):
    return examples[0]


def train(args):
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)

    cache_latents = args.cache_latents

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random seed

    tokenizer = train_util.load_tokenizer(args)

    blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
    if args.dataset_config is not None:
        print(f"Load dataset config from {args.dataset_config}")
        user_config = config_util.load_user_config(args.dataset_config)
        ignored = ["train_data_dir", "in_json"]
        if any(getattr(args, attr) is not None for attr in ignored):
            print(
                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                    ", ".join(ignored)
                )
            )
    else:
        user_config = {
            "datasets": [
                {
                    "subsets": [
                        {
                            "image_dir": args.train_data_dir,
                            "metadata_file": args.in_json,
                        }
                    ]
                }
            ]
        }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group)
        return
    if len(train_dataset_group) == 0:
        print(
            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
        )
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # prepare the accelerator
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # prepare dtypes matching the mixed precision setting and cast as needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the model
    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)

    # verify load/save model formats
    if load_stable_diffusion_format:
        src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
        src_diffusers_model_path = None
    else:
        src_stable_diffusion_ckpt = None
        src_diffusers_model_path = args.pretrained_model_name_or_path

    if args.save_model_as is None:
        save_stable_diffusion_format = load_stable_diffusion_format
        use_safetensors = args.use_safetensors
    else:
        save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
        use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())

    # helper to set the Diffusers xformers flag
    def set_diffusers_xformers_flag(model, valid):
        # model.set_use_memory_efficient_attention_xformers(valid)  # will likely disappear in the next release
        # the pipe is supposed to look for set_use_memory_efficient_attention_xformers recursively on its own,
        # but there is no pipe when only the U-Net is used, so copy the recursive logic here
        # (in 0.10.2 this was rolled back to setting the flag per module again)

        # Recursively walk through all the children.
        # Any children which exposes the set_use_memory_efficient_attention_xformers method
        # gets the message
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        fn_recursive_set_mem_eff(model)

    # incorporate xformers or memory efficient attention into the model
    if args.diffusers_xformers:
        print("Use xformers by Diffusers")
        set_diffusers_xformers_flag(unet, True)
    else:
        # xformers on Windows cannot train in float, so a configuration without xformers must remain available
        print("Disable Diffusers' xformers")
        set_diffusers_xformers_flag(unet, False)
        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

    # prepare training: cache latents
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(vae)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # prepare training: put the models into the appropriate state
    training_models = []
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
    training_models.append(unet)

    if args.train_text_encoder:
        print("enable text encoder training")
        if args.gradient_checkpointing:
            text_encoder.gradient_checkpointing_enable()
        training_models.append(text_encoder)
    else:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)  # the text encoder is not trained
        if args.gradient_checkpointing:
            text_encoder.gradient_checkpointing_enable()
            text_encoder.train()  # required for gradient_checkpointing
        else:
            text_encoder.eval()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    for m in training_models:
        m.requires_grad_(True)
    params = []
    for m in training_models:
        params.extend(m.parameters())
    params_to_optimize = params

    # prepare the classes needed for training
    print("prepare optimizer, data loader etc.")
    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

    # prepare the dataloader
    # a DataLoader worker count of 0 means the main process is used
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # calculate the number of training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # prepare the lr scheduler
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # experimental feature: fp16 training including gradients; cast the whole model to fp16
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        print("enable full fp16 training.")
        unet.to(weight_dtype)
        text_encoder.to(weight_dtype)

    # the accelerator handles device placement and wrapping
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

    # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume training
    if args.resume is not None:
        print(f"resume training from state: {args.resume}")
        accelerator.load_state(args.resume)

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # train
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f"  num examples / サンプル数: {train_dataset_group.num_train_images}")
    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f"  num epochs / epoch数: {num_train_epochs}")
    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )

    if accelerator.is_main_process:
        accelerator.init_trackers("finetuning")

    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        train_dataset_group.set_current_epoch(epoch + 1)

        for m in training_models:
            m.train()

        loss_total = 0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(training_models[0]):  # accumulate does not seem to support multiple models, so pass the first one for now
                with torch.no_grad():
                    if "latents" in batch and batch["latents"] is not None:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        # convert the images to latents
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                        latents = latents * 0.18215
                b_size = latents.shape[0]

                with torch.set_grad_enabled(args.train_text_encoder):
                    # Get the text embedding for conditioning
                    input_ids = batch["input_ids"].to(accelerator.device)
                    encoder_hidden_states = train_util.get_hidden_states(
                        args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
                    )

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = []
                    for m in training_models:
                        params_to_clip.extend(m.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                train_util.sample_images(
                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
                )

            current_loss = loss.detach().item()  # this is already the mean, so batch size should not matter
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)

            # TODO: use a moving average instead
            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        if args.save_every_n_epochs is not None:
            src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
            train_util.save_sd_model_on_epoch_end(
                args,
                accelerator,
                src_path,
                save_stable_diffusion_format,
                use_safetensors,
                save_dtype,
                epoch,
                num_train_epochs,
                global_step,
                unwrap_model(text_encoder),
                unwrap_model(unet),
                vae,
            )

        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

    is_main_process = accelerator.is_main_process
    if is_main_process:
        unet = unwrap_model(unet)
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # free this because memory is needed for the steps that follow

    if is_main_process:
        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
        train_util.save_sd_model_on_train_end(
            args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
        )
        print("model saved.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, False, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)

    parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
    parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)


@@ -4,7 +4,7 @@ from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
+import os

def main(args):
  assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
@@ -29,6 +29,9 @@ def main(args):
    caption_path = image_path.with_suffix(args.caption_extension)
    caption = caption_path.read_text(encoding='utf-8').strip()

+    if not os.path.exists(caption_path):
+      caption_path = os.path.join(image_path, args.caption_extension)

    image_key = str(image_path) if args.full_path else image_path.stem
    if image_key not in metadata:
      metadata[image_key] = {}


@@ -4,7 +4,7 @@ from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
+import os

def main(args):
  assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
@@ -29,6 +29,9 @@ def main(args):
    tags_path = image_path.with_suffix(args.caption_extension)
    tags = tags_path.read_text(encoding='utf-8').strip()

+    if not os.path.exists(tags_path):
+      tags_path = os.path.join(image_path, args.caption_extension)

    image_key = str(image_path) if args.full_path else image_path.stem
    if image_key not in metadata:
      metadata[image_key] = {}

File diff suppressed because it is too large


@@ -103,7 +103,8 @@ def svd(args):
    if args.device:
      mat = mat.to(args.device)

-    # print(mat.size(), mat.device, rank, in_dim, out_dim)
+    # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
    rank = min(rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim

    if conv2d:
@@ -137,27 +138,17 @@ def svd(args):
    lora_weights[lora_name] = (U, Vh)

  # make state dict for LoRA
-  lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True)  # to make state dict
-  lora_sd = lora_network_o.state_dict()
-  print(f"LoRA has {len(lora_sd)} weights.")
-
-  for key in list(lora_sd.keys()):
-    if "alpha" in key:
-      continue
-    lora_name = key.split('.')[0]
-    i = 0 if "lora_up" in key else 1
-    weights = lora_weights[lora_name][i]
-    # print(key, i, weights.size(), lora_sd[key].size())
-    # if len(lora_sd[key].size()) == 4:
-    #   weights = weights.unsqueeze(2).unsqueeze(3)
-    assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
-    lora_sd[key] = weights
+  lora_sd = {}
+  for lora_name, (up_weight, down_weight) in lora_weights.items():
+    lora_sd[lora_name + '.lora_up.weight'] = up_weight
+    lora_sd[lora_name + '.lora_down.weight'] = down_weight
+    lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])

  # load state dict to LoRA and save it
-  info = lora_network_o.load_state_dict(lora_sd)
+  lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
+  lora_network_save.apply_to(text_encoder_o, unet_o)  # create internal module references for state_dict
+
+  info = lora_network_save.load_state_dict(lora_sd)
  print(f"Loading extracted LoRA weights: {info}")

  dir_name = os.path.dirname(args.save_to)

@@ -167,7 +158,7 @@ def svd(args):

  # minimum metadata
  metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}

-  lora_network_o.save_weights(args.save_to, save_dtype, metadata)
+  lora_network_save.save_weights(args.save_to, save_dtype, metadata)
  print(f"LoRA weights are saved to: {args.save_to}")


@@ -21,30 +21,34 @@ class LoRAModule(torch.nn.Module):
""" if alpha == 0 or None, alpha is rank (no scaling). """ """ if alpha == 0 or None, alpha is rank (no scaling). """
super().__init__() super().__init__()
self.lora_name = lora_name self.lora_name = lora_name
self.lora_dim = lora_dim
if org_module.__class__.__name__ == 'Conv2d': if org_module.__class__.__name__ == 'Conv2d':
in_dim = org_module.in_channels in_dim = org_module.in_channels
out_dim = org_module.out_channels out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_dim = min(self.lora_dim, in_dim, out_dim) # if limit_rank:
if self.lora_dim != lora_dim: # self.lora_dim = min(lora_dim, in_dim, out_dim)
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
if org_module.__class__.__name__ == 'Conv2d':
kernel_size = org_module.kernel_size kernel_size = org_module.kernel_size
stride = org_module.stride stride = org_module.stride
padding = org_module.padding padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else: else:
in_dim = org_module.in_features self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
out_dim = org_module.out_features self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor: if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
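For orientation, the pieces set up in this constructor combine at run time roughly as follows: the output of the wrapped original layer is augmented by the low-rank path, scaled by `alpha / dim` (`self.scale`) and the network multiplier. This is a simplified sketch under that assumption, not the actual forward method of `networks/lora.py`:

  def forward(self, x):
    # sketch: original layer output plus the scaled low-rank update
    return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale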
@@ -149,12 +153,13 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
  return network


-def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
-  if os.path.splitext(file)[1] == '.safetensors':
-    from safetensors.torch import load_file, safe_open
-    weights_sd = load_file(file)
-  else:
-    weights_sd = torch.load(file, map_location='cpu')
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
+  if weights_sd is None:
+    if os.path.splitext(file)[1] == '.safetensors':
+      from safetensors.torch import load_file, safe_open
+      weights_sd = load_file(file)
+    else:
+      weights_sd = torch.load(file, map_location='cpu')

  # get dim/alpha mapping
  modules_dim = {}
@@ -174,7 +179,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
  # support old LoRA without alpha
  for key in modules_dim.keys():
    if key not in modules_alpha:
      modules_alpha = modules_dim[key]

  network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
  network.weights_sd = weights_sd
@@ -183,7 +188,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
class LoRANetwork(torch.nn.Module):
  # is it possible to apply conv_in and conv_out?
-  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D"]
+  UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+  UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
  TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
  LORA_PREFIX_UNET = 'lora_unet'
  LORA_PREFIX_TEXT_ENCODER = 'lora_te'
@@ -245,7 +251,12 @@ class LoRANetwork(torch.nn.Module):
        text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
    print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

-    self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
+    # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
+    target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
+    if modules_dim is not None or self.conv_lora_dim is not None:
+      target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+
+    self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
    print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

    self.weights_sd = None
@@ -371,7 +382,7 @@ class LoRANetwork(torch.nn.Module):
    else:
      torch.save(state_dict, file)

-  @staticmethod
+  @ staticmethod
  def set_regions(networks, image):
    image = image.astype(np.float32) / 255.0
    for i, network in enumerate(networks[:3]):


@@ -1,6 +1,6 @@
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
-# Thanks to cloneofsimo and kohya
+# Thanks to cloneofsimo
import argparse
import torch


@@ -23,16 +23,16 @@ def load_state_dict(file_name, dtype):
  return sd


-def save_to_file(file_name, model, state_dict, dtype):
+def save_to_file(file_name, state_dict, dtype):
  if dtype is not None:
    for key in list(state_dict.keys()):
      if type(state_dict[key]) == torch.Tensor:
        state_dict[key] = state_dict[key].to(dtype)

  if os.path.splitext(file_name)[1] == '.safetensors':
-    save_file(model, file_name)
+    save_file(state_dict, file_name)
  else:
-    torch.save(model, file_name)
+    torch.save(state_dict, file_name)


def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
@@ -77,6 +77,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
        # W <- W + U * D
        scale = (alpha / network_dim)
+
+        if device:  # and isinstance(scale, torch.Tensor):
+          scale = scale.to(device)
+
        if not conv2d:  # linear
          weight = weight + ratio * (up_weight @ down_weight) * scale
        elif kernel_size == (1, 1):
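The update applied here is the standard LoRA merge rule: the delta contributed by one module is `up_weight @ down_weight`, scaled by the merge `ratio` and by `alpha / dim`. A minimal, self-contained sketch of the linear case only (the function name and signature are illustrative, not the script's):

import torch

def merge_linear_lora(weight, up_weight, down_weight, alpha, ratio=1.0):
  # W <- W + ratio * (U @ D) * (alpha / dim), where dim is the LoRA rank
  dim = down_weight.size()[0]
  scale = alpha / dim
  return weight + ratio * (up_weight @ down_weight) * scale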
@@ -105,6 +109,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
      mat = mat.squeeze()

      module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
+      module_new_rank = min(module_new_rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim

      U, S, Vh = torch.linalg.svd(mat)
@@ -156,7 +161,7 @@ def merge(args):
  state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)

  print(f"saving model to: {args.save_to}")
-  save_to_file(args.save_to, state_dict, state_dict, save_dtype)
+  save_to_file(args.save_to, state_dict, save_dtype)


if __name__ == '__main__':


@@ -502,6 +502,14 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
As with clip_skip, training at a length different from the model's trained state will likely require a certain number of training images and a longer training time.

- `--persistent_data_loader_workers`
  Specifying this in a Windows environment significantly reduces the wait time between epochs.

- `--max_data_loader_n_workers`
  Specifies the number of processes for data loading. More processes load data faster and use the GPU more efficiently, but they consume main memory. The default is the smaller of `8` and `number of concurrent CPU threads - 1`, so if main memory is tight, or if GPU utilization is already around 90% or higher, lower this to around `2` or `1` while watching those numbers (a short sketch of the resulting worker count follows this section).

- `--logging_dir` / `--log_prefix`
  Options for saving training logs. Specify the destination folder with the logging_dir option. Logs are saved in TensorBoard format.
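As a concrete illustration of the default described above, the training scripts in this commit derive the DataLoader worker count from the option and the CPU count like this (the `8` below stands in for the option's default value):

import os

max_data_loader_n_workers = 8  # the option's default
n_workers = min(max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number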


@@ -7,6 +7,7 @@ import argparse
import itertools
import math
import os
+import toml
from tqdm import tqdm
import torch
@@ -17,346 +18,392 @@ from diffusers import DDPMScheduler
import library.train_util as train_util import library.train_util as train_util
import library.config_util as config_util import library.config_util as config_util
from library.config_util import ( from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
def collate_fn(examples): def collate_fn(examples):
return examples[0] return examples[0]
def train(args): def train(args):
train_util.verify_training_args(args) train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False) train_util.prepare_dataset_args(args, False)
cache_latents = args.cache_latents cache_latents = args.cache_latents
if args.seed is not None: if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する set_seed(args.seed) # 乱数系列を初期化する
tokenizer = train_util.load_tokenizer(args) tokenizer = train_util.load_tokenizer(args)
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
if args.dataset_config is not None: if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}") print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config) user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"] ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored): if any(getattr(args, attr) is not None for attr in ignored):
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) print(
else: "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
user_config = { ", ".join(ignored)
"datasets": [{ )
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) )
}] else:
} user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
if args.no_token_padding: if args.no_token_padding:
train_dataset_group.disable_token_padding() train_dataset_group.disable_token_padding()
if args.debug_dataset: if args.debug_dataset:
train_util.debug_dataset(train_dataset_group) train_util.debug_dataset(train_dataset_group)
return return
if cache_latents: if cache_latents:
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") print("prepare accelerator")
if args.gradient_accumulation_steps > 1: if args.gradient_accumulation_steps > 1:
print(f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong") print(
print( f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です") )
print(
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
)
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む # モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
# verify load/save model formats # verify load/save model formats
if load_stable_diffusion_format: if load_stable_diffusion_format:
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
src_diffusers_model_path = None src_diffusers_model_path = None
else: else:
src_stable_diffusion_ckpt = None src_stable_diffusion_ckpt = None
src_diffusers_model_path = args.pretrained_model_name_or_path src_diffusers_model_path = args.pretrained_model_name_or_path
if args.save_model_as is None: if args.save_model_as is None:
save_stable_diffusion_format = load_stable_diffusion_format save_stable_diffusion_format = load_stable_diffusion_format
use_safetensors = args.use_safetensors use_safetensors = args.use_safetensors
else: else:
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# モデルに xformers とか memory efficient attention を組み込む # モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# 学習を準備する:モデルを適切な状態にする
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # 念のため追加
text_encoder.requires_grad_(train_text_encoder)
if not train_text_encoder:
print("Text Encoder is not trained.")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
if train_text_encoder:
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
else:
trainable_params = unet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000, clip_sample=False)
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth")
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1)
# 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train()
# train==True is required to enable gradient_checkpointing
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
with accelerator.accumulate(unet):
        with torch.no_grad():
            train_dataset_group.cache_latents(vae)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # 学習を準備する:モデルを適切な状態にする
    train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
    unet.requires_grad_(True)  # 念のため追加
    text_encoder.requires_grad_(train_text_encoder)
    if not train_text_encoder:
        print("Text Encoder is not trained.")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # 学習に必要なクラスを準備する
    print("prepare optimizer, data loader etc.")
    if train_text_encoder:
        trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
    else:
        trainable_params = unet.parameters()

    _, _, optimizer = train_util.get_optimizer(args, trainable_params)

    # dataloaderを準備する
    # DataLoaderのプロセス数0はメインプロセスになる
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # 学習ステップ数を計算する
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    if args.stop_text_encoder_training is None:
        args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end

    # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer)

    # 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        print("enable full fp16 training.")
        unet.to(weight_dtype)
        text_encoder.to(weight_dtype)

    # acceleratorがなんかよろしくやってくれるらしい
    if train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

    if not train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error

    # 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resumeする
    if args.resume is not None:
        print(f"resume training from state: {args.resume}")
        accelerator.load_state(args.resume)

    # epoch数を計算する
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # 学習する
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f"  num epochs / epoch数: {num_train_epochs}")
    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )

    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth")

    loss_list = []
    loss_total = 0.0
    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        train_dataset_group.set_current_epoch(epoch + 1)

        # 指定したステップ数までText Encoderを学習するepoch最初の状態
        unet.train()
        # train==True is required to enable gradient_checkpointing
        if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
            text_encoder.train()

        for step, batch in enumerate(train_dataloader):
            # 指定したステップ数でText Encoderの学習を止める
            if global_step == args.stop_text_encoder_training:
                print(f"stop text encoder training at step {global_step}")
                if not args.gradient_checkpointing:
                    text_encoder.train(False)
                text_encoder.requires_grad_(False)

            with accelerator.accumulate(unet):
                with torch.no_grad():
                    # latentに変換
                    if cache_latents:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * 0.18215
                b_size = latents.shape[0]

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

                # Get the text embedding for conditioning
                with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
                    input_ids = batch["input_ids"].to(accelerator.device)
                    encoder_hidden_states = train_util.get_hidden_states(
                        args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
if train_text_encoder:
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
params_to_clip = unet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end(
args,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, False, True)
    train_util.add_training_arguments(parser, True)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)

    parser.add_argument(
        "--no_token_padding",
        action="store_true",
        help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作",
    )
    parser.add_argument(
        "--stop_text_encoder_training",
        type=int,
        default=None,
        help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
    )

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)
train(args)


@@ -64,6 +64,10 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
  * LoRAのRANKを指定します``--network_dim=4``など。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
* `--network_alpha`
  * アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
* `--persistent_data_loader_workers`
  * Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
* `--max_data_loader_n_workers`
  * データ読み込みのプロセス数を指定します。プロセス数が多いとデータ読み込みが速くなりGPUを効率的に利用できますが、メインメモリを消費します。デフォルトは「`8` または `CPU同時実行スレッド数-1` の小さいほう」なので、メインメモリに余裕がない場合や、GPU使用率が90%程度以上なら、それらの数値を見ながら `2` または `1` 程度まで下げてください(設定の対応関係はこのリストの後のスケッチを参照 / see the sketch after this list)。
* `--network_weights`
  * 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
* `--network_train_unet_only`
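
The two data-loader options above end up as arguments to `torch.utils.data.DataLoader`, which is how the training scripts in this commit use them. Below is a minimal, self-contained sketch of that mapping; the dummy dataset, the concrete values, and the `__main__` guard are illustrative assumptions and not code from this repository.

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset


class DummyDataset(Dataset):
    """Stand-in for the training dataset; only here to make the sketch runnable."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(3, 64, 64)


if __name__ == "__main__":
    # Hypothetical values that would normally come from the command line:
    #   --max_data_loader_n_workers=2 --persistent_data_loader_workers
    max_data_loader_n_workers = 2
    persistent_data_loader_workers = True

    # The scripts cap the worker count at (CPU count - 1).
    n_workers = min(max_data_loader_n_workers, os.cpu_count() - 1)

    loader = DataLoader(
        DummyDataset(),
        batch_size=1,
        shuffle=True,
        num_workers=n_workers,  # more workers: faster loading, more main-memory use
        # persistent_workers keeps worker processes alive between epochs (requires
        # num_workers > 0), which is what shortens the per-epoch wait on Windows.
        persistent_workers=persistent_data_loader_workers and n_workers > 0,
    )

    for batch in loader:
        pass  # a training step would consume the batch here
```

Lowering `--max_data_loader_n_workers` therefore trades loading speed for main memory, while `--persistent_data_loader_workers` only changes how long the worker processes live.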


@@ -3,6 +3,7 @@ import argparse
import gc
import math
import os
import toml
from tqdm import tqdm
import torch
@@ -13,8 +14,8 @@ from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)

imagenet_templates_small = [
@@ -71,454 +72,500 @@ imagenet_style_templates_small = [
def collate_fn(examples):
    return examples[0]


def train(args):
    if args.output_name is None:
        args.output_name = args.token_string
    use_template = args.use_object_template or args.use_style_template

    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)

    cache_latents = args.cache_latents

    if args.seed is not None:
        set_seed(args.seed)

    tokenizer = train_util.load_tokenizer(args)

    # acceleratorを準備する
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # mixed precisionに対応した型を用意しておき適宜castする
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # モデルを読み込む
    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)

    # Convert the init_word to token_id
    if args.init_word is not None:
        init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
        if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
            print(
                f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
            )
    else:
        init_token_ids = None

    # add new word to tokenizer, count is num_vectors_per_token
    token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
    num_added_tokens = tokenizer.add_tokens(token_strings)
    assert (
        num_added_tokens == args.num_vectors_per_token
    ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"

    token_ids = tokenizer.convert_tokens_to_ids(token_strings)
    print(f"tokens are added: {token_ids}")
    assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
    assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data
    if init_token_ids is not None:
        for i, token_id in enumerate(token_ids):
            token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())

    # load weights
    if args.weights is not None:
        embeddings = load_weights(args.weights)
        assert len(token_ids) == len(
            embeddings
        ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
        # print(token_ids, embeddings.size())
        for token_id, embedding in zip(token_ids, embeddings):
            token_embeds[token_id] = embedding
            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
        print(f"weighs loaded")

    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")

    # データセットを準備する
    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
    if args.dataset_config is not None:
        print(f"Load dataset config from {args.dataset_config}")
        user_config = config_util.load_user_config(args.dataset_config)
        ignored = ["train_data_dir", "reg_data_dir", "in_json"]
        if any(getattr(args, attr) is not None for attr in ignored):
            print(
                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                    ", ".join(ignored)
                )
            )
    else:
        use_dreambooth_method = args.in_json is None
        if use_dreambooth_method:
            print("Use DreamBooth method.")
            user_config = {
                "datasets": [
                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
                ]
            }
        else:
            print("Train with captions.")
            user_config = {
                "datasets": [
                    {
                        "subsets": [
                            {
                                "image_dir": args.train_data_dir,
                                "metadata_file": args.in_json,
                            }
                        ]
                    }
                ]
            }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
    if use_template:
        print("use template for training captions. is object: {args.use_object_template}")
        templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
        replace_to = " ".join(token_strings)
        captions = []
        for tmpl in templates:
            captions.append(tmpl.format(replace_to))
        train_dataset_group.add_replacement("", captions)

        if args.num_vectors_per_token > 1:
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None
    else:
        if args.num_vectors_per_token > 1:
            replace_to = " ".join(token_strings)
            train_dataset_group.add_replacement(args.token_string, replace_to)
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group, show_input_ids=True)
        return
    if len(train_dataset_group) == 0:
        print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # モデルに xformers とか memory efficient attention を組み込む
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

    # 学習を準備する
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(vae)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    # 学習に必要なクラスを準備する
    print("prepare optimizer, data loader etc.")
    trainable_params = text_encoder.get_input_embeddings().parameters()
    _, _, optimizer = train_util.get_optimizer(args, trainable_params)

    # dataloaderを準備する
    # DataLoaderのプロセス数0はメインプロセスになる
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # 学習ステップ数を計算する
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # lr schedulerを用意する
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # acceleratorがなんかよろしくやってくれるらしい
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
    # print(len(index_no_updates), torch.sum(index_no_updates))
    orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder.requires_grad_(True)
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
        unet.train()
    else:
        unet.eval()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1)
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad():
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
index_no_updates
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
# end of epoch
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + "." + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"save trained model to {ckpt_file}")
        save_weights(ckpt_file, updated_embs, save_dtype)
        print("model saved.")


def save_weights(file, updated_embs, save_dtype):
    state_dict = {"emb_params": updated_embs}

    if save_dtype is not None:
        for key in list(state_dict.keys()):
            v = state_dict[key]
            v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import save_file

        save_file(state_dict, file)
    else:
        torch.save(state_dict, file)  # can be loaded in Web UI


def load_weights(file):
    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        # compatible to Web UI's file format
        data = torch.load(file, map_location="cpu")
        if type(data) != dict:
            raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")

        if "string_to_param" in data:  # textual inversion embeddings
            data = data["string_to_param"]
            if hasattr(data, "_parameters"):  # support old PyTorch?
                data = getattr(data, "_parameters")

    emb = next(iter(data.values()))
    if type(emb) != torch.Tensor:
        raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")

    if len(emb.size()) == 1:
        emb = emb.unsqueeze(0)

    return emb


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, False)
    train_util.add_training_arguments(parser, True)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="pt",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
    )
    parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
    parser.add_argument(
        "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
    )
    parser.add_argument(
        "--token_string",
        type=str,
        default=None,
        help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
    )
    parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
    parser.add_argument(
        "--use_object_template",
        action="store_true",
        help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
    )
    parser.add_argument(
        "--use_style_template",
        action="store_true",
        help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
    )

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)