add sdxl fine-tuning and LoRA
@@ -1061,6 +1061,16 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
     return text_model, vae, unet


+def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
+    # only for reference
+    version_str = "sd"
+    if v2:
+        version_str += "_v2"
+    else:
+        version_str += "_v1"
+    if v_parameterization:
+        version_str += "_v"
+    return version_str


 def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
     def convert_key(key):
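For reference, the new helper simply composes a version tag from its flags: `get_model_version_str_for_sd1_sd2(True, True)` yields "sd_v2_v", while `(False, False)` yields "sd_v1".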
@@ -6,6 +6,10 @@ from library import model_util
 from library import sdxl_original_unet


+VAE_SCALE_FACTOR = 0.13025
+MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"
+
+
 def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
     SDXL_KEY_PREFIX = "conditioner.embedders.1.model."

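These two constants pin down SDXL specifics: the VAE scaling factor differs from the 0.18215 used by SD1/2, and the version string identifies the SDXL base v0.9 checkpoint layout. The factor is applied when mapping between latents and VAE space, as in the generation script later in this commit:

    # decoding latents back to image space (from sdxl_gen_img.py below)
    latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
    image = vae.decode(latents).sample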
@@ -76,8 +80,8 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
     return new_sd, text_projection, logit_scale


-def load_models_from_sdxl_checkpoint(model_type, ckpt_path, map_location):
-    # model_type is reserved to future use
+def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
+    # model_version is reserved for future use

     # Load the state dict
     if model_util.is_safetensors(ckpt_path):
@@ -1069,8 +1069,8 @@ class SdxlUNet2DConditionModel(nn.Module):
         t_emb = t_emb.to(x.dtype)
         emb = self.time_embed(t_emb)

-        assert y.shape[0] == x.shape[0]
-        assert x.dtype == y.dtype
+        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
+        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
         # assert x.dtype == self.dtype
         emb = emb + self.label_emb(y)
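Here `y` is SDXL's added conditioning vector (the pooled output of the second text encoder concatenated with the size embeddings; see `get_size_embeddings` in `sdxl_train_util.py` below), so validating its batch size and dtype against `x` up front gives a clearer error than a failure deep inside `label_emb`.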
@@ -1105,6 +1105,8 @@ class SdxlUNet2DConditionModel(nn.Module):


 if __name__ == "__main__":
+    import time
+
     print("create unet")
     unet = SdxlUNet2DConditionModel()
@@ -1132,8 +1134,11 @@ if __name__ == "__main__":
    print("start training")
    steps = 10
    batch_size = 1

    for step in range(steps):
        print(f"step {step}")
+        if step == 1:
+            time_start = time.perf_counter()
+
        x = torch.randn(batch_size, 4, 128, 128).cuda()  # 1024x1024
        t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
@@ -1149,3 +1154,6 @@ if __name__ == "__main__":
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
+
+    time_end = time.perf_counter()
+    print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
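The timer deliberately starts at `step == 1`, excluding the first step, which pays one-time costs (CUDA context setup, cuDNN autotuning, allocator warm-up); the reported time therefore covers only the `steps - 1` steady-state steps.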
 384  library/sdxl_train_util.py  Normal file
@@ -0,0 +1,384 @@
+import argparse
+import gc
+import math
+import os
+from types import SimpleNamespace
+from typing import Any
+
+import torch
+from tqdm import tqdm
+from transformers import CLIPTokenizer
+import open_clip
+
+from library import model_util, sdxl_model_util, train_util
+
+TOKENIZER_PATH = "openai/clip-vit-large-patch14"
+
+DEFAULT_NOISE_OFFSET = 0.0357
+
+
+# TODO: separate checkpoint for each U-Net/Text Encoder/VAE
+def load_target_model(args, accelerator, model_version: str, weight_dtype):
+    # load models for each process
+    for pi in range(accelerator.state.num_processes):
+        if pi == accelerator.state.local_process_index:
+            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+
+            (
+                load_stable_diffusion_format,
+                text_encoder1,
+                text_encoder2,
+                vae,
+                unet,
+                text_projection,
+                logit_scale,
+                ckpt_info,
+            ) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
+
+            # work on low-ram device
+            if args.lowram:
+                text_encoder1.to(accelerator.device)
+                text_encoder2.to(accelerator.device)
+                unet.to(accelerator.device)
+                vae.to(accelerator.device)
+
+        gc.collect()
+        torch.cuda.empty_cache()
+        accelerator.wait_for_everyone()
+
+    text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
+
+    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
+
+
+def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
+    # only supports StableDiffusion
+    name_or_path = args.pretrained_model_name_or_path
+    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
+    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
+    assert (
+        load_stable_diffusion_format
+    ), f"only supports StableDiffusion format for SDXL / SDXLではStableDiffusion形式のみサポートしています: {name_or_path}"
+
+    print(f"load StableDiffusion checkpoint: {name_or_path}")
+    (
+        text_encoder1,
+        text_encoder2,
+        vae,
+        unet,
+        text_projection,
+        logit_scale,
+        ckpt_info,
+    ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
+
+    # load the VAE
+    if args.vae is not None:
+        vae = model_util.load_vae(args.vae, weight_dtype)
+        print("additional VAE loaded")
+
+    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
+
+
+class WrapperTokenizer:
+    # make the open_clip tokenizer usable in the same way as a HuggingFace tokenizer
+    def __init__(self):
+        open_clip_tokenizer = open_clip.tokenizer._tokenizer
+        self.model_max_length = 77
+        self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
+        self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
+        self.pad_token_id = 0  # inferred from the tokenizer's output
+
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        return self.tokenize(*args, **kwds)
+
+    def tokenize(self, text, padding, truncation, max_length, return_tensors):
+        assert padding == "max_length"
+        assert truncation == True
+        assert return_tensors == "pt"
+        input_ids = open_clip.tokenize(text, context_length=max_length)
+        return SimpleNamespace(**{"input_ids": input_ids})
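The wrapper mimics only the single call pattern the training code uses; a minimal usage sketch:

    tokenizer2 = WrapperTokenizer()
    out = tokenizer2("a photo of a cat", padding="max_length", truncation=True, max_length=77, return_tensors="pt")
    print(out.input_ids.shape)  # torch.Size([1, 77])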
+
+
+def load_tokenizers(args: argparse.Namespace):
+    print("prepare tokenizers")
+    original_path = TOKENIZER_PATH
+
+    tokenizer1: CLIPTokenizer = None
+    if args.tokenizer_cache_dir:
+        local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
+        if os.path.exists(local_tokenizer_path):
+            print(f"load tokenizer from cache: {local_tokenizer_path}")
+            tokenizer1 = CLIPTokenizer.from_pretrained(local_tokenizer_path)
+
+    if tokenizer1 is None:
+        tokenizer1 = CLIPTokenizer.from_pretrained(original_path)
+
+    if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
+        print(f"save Tokenizer to cache: {local_tokenizer_path}")
+        tokenizer1.save_pretrained(local_tokenizer_path)
+
+    if hasattr(args, "max_token_length") and args.max_token_length is not None:
+        print(f"update token length: {args.max_token_length}")
+
+    # tokenizer2 is from open_clip
+    # TODO caching
+    tokenizer2 = WrapperTokenizer()
+
+    return [tokenizer1, tokenizer2]
+
+
+def get_hidden_states(
+    args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
+):
+    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
+    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77
+
+    # text_encoder1
+    enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
+    hidden_states1 = enc_out["hidden_states"][11]
+
+    # text_encoder2
+    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
+    hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
+    pool2 = enc_out["pooler_output"]
+
+    if args.max_token_length is not None:
+        # bs*3, 77, 768 or 1024
+        # encoder1: restore the concatenated <BOS>...<EOS> triples to a single <BOS>...<EOS>
+        states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
+        for i in range(1, args.max_token_length, tokenizer1.model_max_length):
+            states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # from after <BOS> to before <EOS>
+        states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
+        hidden_states1 = torch.cat(states_list, dim=1)
+
+        # v2: restore the concatenated <BOS>...<EOS> <PAD> ... triples to <BOS>...<EOS> <PAD> ... (not sure this implementation is right)
+        states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
+        for i in range(1, args.max_token_length, tokenizer2.model_max_length):
+            chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to before the last token
+            if i > 0:
+                for j in range(len(chunk)):
+                    if input_ids2[j, 1] == tokenizer2.eos_token_id:  # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
+                        chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
+            states_list.append(chunk)  # from after <BOS> to before <EOS>
+        states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
+        hidden_states2 = torch.cat(states_list, dim=1)
+
+    if weight_dtype is not None:
+        # this is required for additional network training
+        hidden_states1 = hidden_states1.to(weight_dtype)
+        hidden_states2 = hidden_states2.to(weight_dtype)
+
+    return hidden_states1, hidden_states2, pool2
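Note the asymmetry between the encoders: `hidden_states1` is taken from hidden-state index 11 of the first (HuggingFace CLIP) encoder, while `hidden_states2` uses the penultimate layer of the open_clip encoder, whose pooled output (`pool2`) is also returned because SDXL feeds it into the conditioning vector.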
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+        device=timesteps.device
+    )
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
+
+def get_timestep_embedding(x, outdim):
+    assert len(x.shape) == 2
+    b, dims = x.shape[0], x.shape[1]
+    x = torch.flatten(x)
+    emb = timestep_embedding(x, outdim)
+    emb = torch.reshape(emb, (b, dims * outdim))
+    return emb
+
+
+def get_size_embeddings(orig_size, crop_size, target_size, device):
+    emb1 = get_timestep_embedding(orig_size, 256)
+    emb2 = get_timestep_embedding(crop_size, 256)
+    emb3 = get_timestep_embedding(target_size, 256)
+    vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
+    return vector
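Each argument to `get_size_embeddings` is a (batch, 2) tensor of spatial coordinates (original size, crop offset, and target size in SDXL's conditioning), and each coordinate expands to a 256-dim sinusoid, so the result is 3 × 2 × 256 = 1536 dims. A quick shape check (a minimal sketch):

    sizes = torch.tensor([[1024, 1024]])  # (batch, 2)
    v = get_size_embeddings(sizes, sizes, sizes, "cpu")
    print(v.shape)  # torch.Size([1, 1536])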
+
+
+def save_sd_model_on_train_end(
+    args: argparse.Namespace,
+    src_path: str,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    save_dtype: torch.dtype,
+    epoch: int,
+    global_step: int,
+    text_encoder1,
+    text_encoder2,
+    unet,
+    vae,
+    text_projection,
+    logit_scale,
+    ckpt_info,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        sdxl_model_util.save_stable_diffusion_checkpoint(
+            ckpt_file,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            epoch_no,
+            global_step,
+            ckpt_info,
+            vae,
+            text_projection,
+            logit_scale,
+            save_dtype,
+        )
+
+    def diffusers_saver(out_dir):
+        raise NotImplementedError("diffusers_saver is not implemented")
+
+    train_util.save_sd_model_on_train_end_common(
+        args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
+    )
+
+
+# Saving at epoch end and saving every N steps are merged into one function,
+# since the metadata includes epoch/step and the arguments are identical.
+# on_epoch_end: True to save at the end of an epoch, False to save after a number of steps
+def save_sd_model_on_epoch_end_or_stepwise(
+    args: argparse.Namespace,
+    on_epoch_end: bool,
+    accelerator,
+    src_path,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    save_dtype: torch.dtype,
+    epoch: int,
+    num_train_epochs: int,
+    global_step: int,
+    text_encoder1,
+    text_encoder2,
+    unet,
+    vae,
+    text_projection,
+    logit_scale,
+    ckpt_info,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        sdxl_model_util.save_stable_diffusion_checkpoint(
+            ckpt_file,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            epoch_no,
+            global_step,
+            ckpt_info,
+            vae,
+            text_projection,
+            logit_scale,
+            save_dtype,
+        )
+
+    def diffusers_saver(out_dir):
+        raise NotImplementedError("diffusers_saver is not implemented")
+
+    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
+        args,
+        on_epoch_end,
+        accelerator,
+        save_stable_diffusion_format,
+        use_safetensors,
+        epoch,
+        num_train_epochs,
+        global_step,
+        sd_saver,
+        diffusers_saver,
+    )
+
+
+# cache the Text Encoder outputs
+# if weight_dtype is specified, the Text Encoders themselves and their outputs are cast to weight_dtype
+def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype):
+    print("caching text encoder outputs")
+
+    tokenizer1, tokenizer2 = tokenizers
+    text_encoder1, text_encoder2 = text_encoders
+    text_encoder1.to(accelerator.device)
+    text_encoder2.to(accelerator.device)
+    if weight_dtype is not None:
+        text_encoder1.to(dtype=weight_dtype)
+        text_encoder2.to(dtype=weight_dtype)
+
+    text_encoder1_cache = {}
+    text_encoder2_cache = {}
+    for batch in tqdm(data_loader):
+        input_ids1_batch = batch["input_ids"]
+        input_ids2_batch = batch["input_ids2"]
+
+        # split batch to avoid OOM
+        # TODO specify batch size by args
+        for input_ids1, input_ids2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
+            # remove input_ids already in cache
+            input_ids1 = input_ids1.squeeze(0)
+            input_ids2 = input_ids2.squeeze(0)
+            input_ids1 = [i for i in input_ids1 if i not in text_encoder1_cache]
+            input_ids2 = [i for i in input_ids2 if i not in text_encoder2_cache]
+            assert len(input_ids1) == len(input_ids2)
+            if len(input_ids1) == 0:
+                continue
+            input_ids1 = torch.stack(input_ids1).to(accelerator.device)
+            input_ids2 = torch.stack(input_ids2).to(accelerator.device)
+
+            with torch.no_grad():
+                encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
+                    args,
+                    input_ids1,
+                    input_ids2,
+                    tokenizer1,
+                    tokenizer2,
+                    text_encoder1,
+                    text_encoder2,
+                    None if not args.full_fp16 else weight_dtype,
+                )
+            encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu")
+            encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu")
+            pool2 = pool2.to("cpu")
+            for input_id1, input_id2, hidden_states1, hidden_states2, p2 in zip(
+                input_ids1, input_ids2, encoder_hidden_states1, encoder_hidden_states2, pool2
+            ):
+                text_encoder1_cache[tuple(input_id1.tolist())] = hidden_states1
+                text_encoder2_cache[tuple(input_id2.tolist())] = (hidden_states2, p2)
+    return text_encoder1_cache, text_encoder2_cache
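Both caches are keyed by the full tuple of token ids, so a later training step can skip the Text Encoders entirely; a lookup would look like this (sketch only, the `input_ids*` variables are placeholders):

    hidden_states1 = text_encoder1_cache[tuple(input_ids1.tolist())]
    hidden_states2, pool2 = text_encoder2_cache[tuple(input_ids2.tolist())]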
+
+
+def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
+    )
+
+
+def verify_sdxl_training_args(args: argparse.Namespace):
+    assert (
+        not args.v2 and not args.v_parameterization
+    ), "v2 or v_parameterization cannot be enabled in SDXL training / SDXL学習ではv2とv_parameterizationを有効にすることはできません"
+    if args.clip_skip is not None:
+        print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
+
+    if args.multires_noise_iterations:
+        print(
+            f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
+        )
+    else:
+        if args.noise_offset is None:
+            args.noise_offset = DEFAULT_NOISE_OFFSET
+        elif args.noise_offset != DEFAULT_NOISE_OFFSET:
+            print(
+                f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
+            )
+        print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
+
+    assert (
+        not args.weighted_captions
+    ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
@@ -798,6 +798,19 @@ class BaseDataset(torch.utils.data.Dataset):
     def is_latent_cacheable(self):
         return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])

+    def is_text_encoder_output_cacheable(self):
+        return all(
+            [
+                not (
+                    subset.caption_dropout_rate > 0
+                    or subset.shuffle_caption
+                    or subset.token_warmup_step > 0
+                    or subset.caption_tag_dropout_rate > 0
+                )
+                for subset in self.subsets
+            ]
+        )

     def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
         expected_latents_size = (reso[1] // 8, reso[0] // 8)  # note that bucket_reso is W x H
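`is_text_encoder_output_cacheable` rejects exactly the options that change a caption from step to step (caption/tag dropout, shuffling, token warmup): with any of them enabled, the encoder outputs are no longer a fixed function of the stored caption, so a per-caption cache would silently freeze the augmentation.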
@@ -850,7 +863,7 @@ class BaseDataset(torch.utils.data.Dataset):
                 continue

             cache_available = self.is_disk_cached_latents_is_expected(
-                info.bucket_reso, info.latents_npz, info.latents_npz_flipped if self.flip_aug else None
+                info.bucket_reso, info.latents_npz, info.latents_npz_flipped if subset.flip_aug else None
             )

             if cache_available:  # do not add to batch
@@ -1719,6 +1732,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
     def is_latent_cacheable(self) -> bool:
         return all([dataset.is_latent_cacheable() for dataset in self.datasets])

+    def is_text_encoder_output_cacheable(self) -> bool:
+        return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
+
     def set_current_epoch(self, epoch):
         for dataset in self.datasets:
             dataset.set_current_epoch(epoch)
@@ -3284,11 +3300,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
     return text_encoder, vae, unet, load_stable_diffusion_format


+# TODO remove this function in the future
 def transform_if_model_is_DDP(text_encoder, unet, network=None):
     # Transform text_encoder, unet and network from DistributedDataParallel
     return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)


+def transform_models_if_DDP(models):
+    # Transform text_encoder, unet and network from DistributedDataParallel
+    return [model.module if type(model) == DDP else model for model in models if model is not None]
+
+
 def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
     # load models for each process
     for pi in range(accelerator.state.num_processes):
@@ -3430,6 +3452,42 @@ def save_sd_model_on_epoch_end_or_stepwise(
     text_encoder,
     unet,
     vae,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        model_util.save_stable_diffusion_checkpoint(
+            args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
+        )
+
+    def diffusers_saver(out_dir):
+        model_util.save_diffusers_checkpoint(
+            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
+        )
+
+    save_sd_model_on_epoch_end_or_stepwise_common(
+        args,
+        on_epoch_end,
+        accelerator,
+        save_stable_diffusion_format,
+        use_safetensors,
+        epoch,
+        num_train_epochs,
+        global_step,
+        sd_saver,
+        diffusers_saver,
+    )
+
+
+def save_sd_model_on_epoch_end_or_stepwise_common(
+    args: argparse.Namespace,
+    on_epoch_end: bool,
+    accelerator,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    epoch: int,
+    num_train_epochs: int,
+    global_step: int,
+    sd_saver,
+    diffusers_saver,
 ):
     if on_epoch_end:
         epoch_no = epoch + 1
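The pattern in this and the following hunks: the format-specific saving calls move into `sd_saver`/`diffusers_saver` closures, and the shared `_common` helpers own the epoch/step bookkeeping and upload logic. This is what lets `sdxl_train_util` above plug in SDXL-specific savers without duplicating that bookkeeping.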
@@ -3457,9 +3515,7 @@ def save_sd_model_on_epoch_end_or_stepwise(

         ckpt_file = os.path.join(args.output_dir, ckpt_name)
         print(f"\nsaving checkpoint: {ckpt_file}")
-        model_util.save_stable_diffusion_checkpoint(
-            args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
-        )
+        sd_saver(ckpt_file, epoch_no, global_step)

         if args.huggingface_repo_id is not None:
             huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
@@ -3483,9 +3539,8 @@ def save_sd_model_on_epoch_end_or_stepwise(
         out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step))

         print(f"\nsaving model: {out_dir}")
-        model_util.save_diffusers_checkpoint(
-            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
-        )
+        diffusers_saver(out_dir)
         if args.huggingface_repo_id is not None:
             huggingface_util.upload(args, out_dir, "/" + model_name)

@@ -3578,6 +3633,30 @@ def save_sd_model_on_train_end(
     text_encoder,
     unet,
     vae,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        model_util.save_stable_diffusion_checkpoint(
+            args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
+        )
+
+    def diffusers_saver(out_dir):
+        model_util.save_diffusers_checkpoint(
+            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
+        )
+
+    save_sd_model_on_train_end_common(
+        args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
+    )
+
+
+def save_sd_model_on_train_end_common(
+    args: argparse.Namespace,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    epoch: int,
+    global_step: int,
+    sd_saver,
+    diffusers_saver,
 ):
     model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)

@@ -3588,9 +3667,8 @@ def save_sd_model_on_train_end(
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
-        model_util.save_stable_diffusion_checkpoint(
-            args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
-        )
+        sd_saver(ckpt_file, epoch, global_step)
        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
    else:
@@ -3598,9 +3676,8 @@ def save_sd_model_on_train_end(
        os.makedirs(out_dir, exist_ok=True)

        print(f"save trained model as Diffusers to {out_dir}")
-        model_util.save_diffusers_checkpoint(
-            args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
-        )
+        diffusers_saver(out_dir)
        if args.huggingface_repo_id is not None:
            huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)

@@ -5,7 +5,9 @@

 import math
 import os
-from typing import List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
 import numpy as np
 import torch
 import re
@@ -400,7 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs):
     return down_lr_weight, mid_lr_weight, up_lr_weight


-def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs):
+def create_network(
+    multiplier: float,
+    network_dim: Optional[int],
+    network_alpha: Optional[float],
+    vae: AutoencoderKL,
+    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
+    unet,
+    neuron_dropout: Optional[float] = None,
+    **kwargs,
+):
     if network_dim is None:
         network_dim = 4  # default
     if network_alpha is None:
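`text_encoder` may now be a list, which is how the SDXL path passes both encoders. For example (a sketch, assuming models loaded as elsewhere in this commit):

    network = create_network(1.0, 16, 8.0, vae, [text_encoder1, text_encoder2], unet)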
@@ -719,33 +730,36 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
 class LoRANetwork(torch.nn.Module):
     NUM_OF_BLOCKS = 12  # number of up/down blocks in the equivalent full model

-    # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
     UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
     UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"

+    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
+    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
     def __init__(
         self,
-        text_encoder,
+        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
         unet,
-        multiplier=1.0,
-        lora_dim=4,
-        alpha=1,
-        dropout=None,
-        rank_dropout=None,
-        module_dropout=None,
-        conv_lora_dim=None,
-        conv_alpha=None,
-        block_dims=None,
-        block_alphas=None,
-        conv_block_dims=None,
-        conv_block_alphas=None,
-        modules_dim=None,
-        modules_alpha=None,
-        module_class=LoRAModule,
-        varbose=False,
+        multiplier: float = 1.0,
+        lora_dim: int = 4,
+        alpha: float = 1,
+        dropout: Optional[float] = None,
+        rank_dropout: Optional[float] = None,
+        module_dropout: Optional[float] = None,
+        conv_lora_dim: Optional[int] = None,
+        conv_alpha: Optional[float] = None,
+        block_dims: Optional[List[int]] = None,
+        block_alphas: Optional[List[float]] = None,
+        conv_block_dims: Optional[List[int]] = None,
+        conv_block_alphas: Optional[List[float]] = None,
+        modules_dim: Optional[Dict[str, int]] = None,
+        modules_alpha: Optional[Dict[str, int]] = None,
+        module_class: Type[object] = LoRAModule,
+        varbose: Optional[bool] = False,
     ) -> None:
         """
         LoRA network: there are a lot of arguments, but they follow these patterns
@@ -783,8 +797,21 @@ class LoRANetwork(torch.nn.Module):
             print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

         # create module instances
-        def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
-            prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
+        def create_modules(
+            is_unet: bool,
+            text_encoder_idx: Optional[int],  # None, 1, 2
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+        ) -> List[LoRAModule]:
+            prefix = (
+                self.LORA_PREFIX_UNET
+                if is_unet
+                else (
+                    self.LORA_PREFIX_TEXT_ENCODER
+                    if text_encoder_idx is None
+                    else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
+                )
+            )
             loras = []
             skipped = []
             for name, module in root_module.named_modules():
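With the per-encoder prefixes, SDXL LoRA keys for the two text encoders are namespaced `lora_te1_...` and `lora_te2_...` instead of SD1/2's single `lora_te_...`; the module path is joined with `_`, so a query projection in the first encoder becomes e.g. `lora_te1_text_model_encoder_layers_0_self_attn_q_proj` (the same naming that `sdxl_merge_lora.py` below reconstructs).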
@@ -800,11 +827,14 @@ class LoRANetwork(torch.nn.Module):

                        dim = None
                        alpha = None

                        if modules_dim is not None:
+                            # modules are explicitly specified
                            if lora_name in modules_dim:
                                dim = modules_dim[lora_name]
                                alpha = modules_alpha[lora_name]
                        elif is_unet and block_dims is not None:
+                            # block_dims is specified for the U-Net
                            block_idx = get_block_index(lora_name)
                            if is_linear or is_conv2d_1x1:
                                dim = block_dims[block_idx]
@@ -813,6 +843,7 @@ class LoRANetwork(torch.nn.Module):
                                dim = conv_block_dims[block_idx]
                                alpha = conv_block_alphas[block_idx]
                        else:
+                            # default: target all modules
                            if is_linear or is_conv2d_1x1:
                                dim = self.lora_dim
                                alpha = self.alpha
@@ -821,6 +852,7 @@ class LoRANetwork(torch.nn.Module):
                                alpha = self.conv_alpha

                        if dim is None or dim == 0:
+                            # report the skipped modules
                            if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
                                skipped.append(lora_name)
                            continue
@@ -838,7 +870,16 @@ class LoRANetwork(torch.nn.Module):
                loras.append(lora)
            return loras, skipped

-        self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+        # create LoRA for text encoder
+        # building all modules every time is wasteful and needs reconsideration
+        self.text_encoder_loras = []
+        skipped_te = []
+        for i, text_encoder in enumerate(text_encoders):
+            text_encoder_loras, skipped = create_modules(False, i + 1, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+            self.text_encoder_loras.extend(text_encoder_loras)
+            skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
@@ -846,7 +887,7 @@ class LoRANetwork(torch.nn.Module):
        if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

-        self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
+        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        skipped = skipped_te + skipped_un
@@ -961,6 +1002,7 @@ class LoRANetwork(torch.nn.Module):

        return lr_weight

+    # it might be good to allow separate learning rates for the two Text Encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []
 258  networks/sdxl_merge_lora.py  Normal file
@@ -0,0 +1,258 @@
+import math
+import argparse
+import os
+import torch
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+from library import sdxl_model_util
+import library.model_util as model_util
+import lora
+
+
+def load_state_dict(file_name, dtype):
+    if os.path.splitext(file_name)[1] == ".safetensors":
+        sd = load_file(file_name)
+    else:
+        sd = torch.load(file_name, map_location="cpu")
+    for key in list(sd.keys()):
+        if type(sd[key]) == torch.Tensor:
+            sd[key] = sd[key].to(dtype)
+    return sd
+
+
+def save_to_file(file_name, model, state_dict, dtype):
+    if dtype is not None:
+        for key in list(state_dict.keys()):
+            if type(state_dict[key]) == torch.Tensor:
+                state_dict[key] = state_dict[key].to(dtype)
+
+    if os.path.splitext(file_name)[1] == ".safetensors":
+        save_file(model, file_name)
+    else:
+        torch.save(model, file_name)
+
+
+def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
+    text_encoder1.to(merge_dtype)
+    text_encoder2.to(merge_dtype)
+    unet.to(merge_dtype)
+
+    # create module map
+    name_to_module = {}
+    for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
+        if i <= 1:
+            if i == 0:
+                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
+            else:
+                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
+            target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+        else:
+            prefix = lora.LoRANetwork.LORA_PREFIX_UNET
+            target_replace_modules = (
+                lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+            )
+
+        for name, module in root_module.named_modules():
+            if module.__class__.__name__ in target_replace_modules:
+                for child_name, child_module in module.named_modules():
+                    if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
+                        lora_name = prefix + "." + name + "." + child_name
+                        lora_name = lora_name.replace(".", "_")
+                        name_to_module[lora_name] = child_module
+
+    for model, ratio in zip(models, ratios):
+        print(f"loading: {model}")
+        lora_sd = load_state_dict(model, merge_dtype)
+
+        print(f"merging...")
+        for key in tqdm(lora_sd.keys()):
+            if "lora_down" in key:
+                up_key = key.replace("lora_down", "lora_up")
+                alpha_key = key[: key.index("lora_down")] + "alpha"
+
+                # find original module for this lora
+                module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
+                if module_name not in name_to_module:
+                    print(f"no module found for LoRA weight: {key}")
+                    continue
+                module = name_to_module[module_name]
+                # print(f"apply {key} to {module}")
+
+                down_weight = lora_sd[key]
+                up_weight = lora_sd[up_key]
+
+                dim = down_weight.size()[0]
+                alpha = lora_sd.get(alpha_key, dim)
+                scale = alpha / dim
+
+                # W <- W + U * D
+                weight = module.weight
+                # print(module_name, down_weight.size(), up_weight.size())
+                if len(weight.size()) == 2:
+                    # linear
+                    weight = weight + ratio * (up_weight @ down_weight) * scale
+                elif down_weight.size()[2:4] == (1, 1):
+                    # conv2d 1x1
+                    weight = (
+                        weight
+                        + ratio
+                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                        * scale
+                    )
+                else:
+                    # conv2d 3x3
+                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+                    # print(conved.size(), weight.size(), module.stride, module.padding)
+                    weight = weight + ratio * conved * scale
+
+                module.weight = torch.nn.Parameter(weight)
+
+
+def merge_lora_models(models, ratios, merge_dtype):
+    base_alphas = {}  # alpha for merged model
+    base_dims = {}
+
+    merged_sd = {}
+    for model, ratio in zip(models, ratios):
+        print(f"loading: {model}")
+        lora_sd = load_state_dict(model, merge_dtype)
+
+        # get alpha and dim
+        alphas = {}  # alpha for current model
+        dims = {}  # dims for current model
+        for key in lora_sd.keys():
+            if "alpha" in key:
+                lora_module_name = key[: key.rfind(".alpha")]
+                alpha = float(lora_sd[key].detach().numpy())
+                alphas[lora_module_name] = alpha
+                if lora_module_name not in base_alphas:
+                    base_alphas[lora_module_name] = alpha
+            elif "lora_down" in key:
+                lora_module_name = key[: key.rfind(".lora_down")]
+                dim = lora_sd[key].size()[0]
+                dims[lora_module_name] = dim
+                if lora_module_name not in base_dims:
+                    base_dims[lora_module_name] = dim
+
+        for lora_module_name in dims.keys():
+            if lora_module_name not in alphas:
+                alpha = dims[lora_module_name]
+                alphas[lora_module_name] = alpha
+                if lora_module_name not in base_alphas:
+                    base_alphas[lora_module_name] = alpha
+
+        print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
+
+        # merge
+        print(f"merging...")
+        for key in tqdm(lora_sd.keys()):
+            if "alpha" in key:
+                continue
+
+            lora_module_name = key[: key.rfind(".lora_")]
+
+            base_alpha = base_alphas[lora_module_name]
+            alpha = alphas[lora_module_name]
+
+            scale = math.sqrt(alpha / base_alpha) * ratio
+
+            if key in merged_sd:
+                assert (
+                    merged_sd[key].size() == lora_sd[key].size()
+                ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
+                merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
+            else:
+                merged_sd[key] = lora_sd[key] * scale
+
+    # set alpha to sd
+    for lora_module_name, alpha in base_alphas.items():
+        key = lora_module_name + ".alpha"
+        merged_sd[key] = torch.tensor(alpha)
+
+    print("merged model")
+    print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
+
+    return merged_sd
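The merged state dict can store only one `alpha` per module (the first model's `base_alpha`), so each tensor is scaled by `sqrt(alpha / base_alpha)`: the scale multiplies both the `lora_down` and `lora_up` halves, their product picks up the full `alpha / base_alpha` factor, and the merged pair behaves under `base_alpha` exactly as the original pair did under its own `alpha`.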
+
+
+def merge(args):
+    assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
+
+    def str_to_dtype(p):
+        if p == "float":
+            return torch.float
+        if p == "fp16":
+            return torch.float16
+        if p == "bf16":
+            return torch.bfloat16
+        return None
+
+    merge_dtype = str_to_dtype(args.precision)
+    save_dtype = str_to_dtype(args.save_precision)
+    if save_dtype is None:
+        save_dtype = merge_dtype
+
+    if args.sd_model is not None:
+        print(f"loading SD model: {args.sd_model}")
+
+        (
+            text_model1,
+            text_model2,
+            vae,
+            unet,
+            text_projection,
+            logit_scale,
+            ckpt_info,
+        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu")
+
+        merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
+
+        print(f"saving SD model to: {args.save_to}")
+        sdxl_model_util.save_stable_diffusion_checkpoint(
+            args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, text_projection, logit_scale, save_dtype
+        )
+    else:
+        state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
+
+        print(f"saving model to: {args.save_to}")
+        save_to_file(args.save_to, state_dict, state_dict, save_dtype)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--save_precision",
+        type=str,
+        default=None,
+        choices=[None, "float", "fp16", "bf16"],
+        help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default="float",
+        choices=["float", "fp16", "bf16"],
+        help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)",
+    )
+    parser.add_argument(
+        "--sd_model",
+        type=str,
+        default=None,
+        help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
+    )
+    parser.add_argument(
+        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
+    )
+    parser.add_argument(
+        "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
+    )
+    parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
+
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    merge(args)
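A typical invocation, merging one LoRA into an SDXL checkpoint (file names are illustrative):

    python networks/sdxl_merge_lora.py --sd_model sd_xl_base.safetensors \
        --save_to merged.safetensors --models my_lora.safetensors --ratios 0.8 \
        --save_precision fp16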
@@ -11,10 +11,13 @@ import numpy as np
 import torch
 from tqdm import tqdm
 from transformers import CLIPTokenizer
-from library import sdxl_model_util
 from diffusers import EulerDiscreteScheduler
 from PIL import Image
 import open_clip
+from safetensors.torch import load_file
+
+from library import model_util, sdxl_model_util
+import networks.lora as lora

 # scheduler: このあたりの設定はSD1/2と同じでいいらしい
 # scheduler: The settings around here seem to be the same as SD1/2
@@ -85,6 +88,13 @@ if __name__ == "__main__":
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--output_dir", type=str, default=".")
+    parser.add_argument(
+        "--lora_weights",
+        type=str,
+        nargs="*",
+        default=[],
+        help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
+    )
    args = parser.parse_args()

    # HuggingFace model id
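Each `--lora_weights` entry is `path;multiplier`, e.g. `--lora_weights "my_lora.safetensors;0.8"`; without the `;multiplier` suffix the multiplier defaults to 1.0 (see the merge loop added below).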
@@ -97,7 +107,7 @@ if __name__ == "__main__":
    # 本体RAMが少ない場合はGPUにロードするといいかも
    # If the main RAM is small, it may be better to load it on the GPU
    text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
-        "sdxl_base_v0-9", args.ckpt_path, "cpu"
+        sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
    )

    # Text Encoder 1 uses the HuggingFace model even in the official SDXL
@@ -134,6 +144,19 @@ if __name__ == "__main__":

    unet.set_use_memory_efficient_attention(True, False)

+    # LoRA
+    for weights_file in args.lora_weights:
+        if ";" in weights_file:
+            weights_file, multiplier = weights_file.split(";")
+            multiplier = float(multiplier)
+        else:
+            multiplier = 1.0
+
+        lora_model, weights_sd = lora.create_network_from_weights(
+            multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True
+        )
+        lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
+
    # prepare embedding
    with torch.no_grad():
        # vector
@@ -248,7 +271,7 @@ if __name__ == "__main__":
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            # latents = 1 / 0.18215 * latents
-            latents = 1 / 0.13025 * latents
+            latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
            latents = latents.to(torch.float32)
            image = vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)
 605  sdxl_train.py  Normal file
@@ -0,0 +1,605 @@
+# training with captions
+
+import argparse
+import gc
+import math
+import os
+from multiprocessing import Value
+
+from tqdm import tqdm
+import torch
+from accelerate.utils import set_seed
+from diffusers import DDPMScheduler
+from library import sdxl_model_util
+
+import library.train_util as train_util
+import library.config_util as config_util
+import library.sdxl_train_util as sdxl_train_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.custom_train_functions as custom_train_functions
+from library.custom_train_functions import (
+    apply_snr_weight,
+    prepare_scheduler_for_custom_training,
+    pyramid_noise_like,
+    apply_noise_offset,
+    scale_v_prediction_loss_like_noise_prediction,
+)
+from library.sdxl_original_unet import SdxlUNet2DConditionModel
+
+
+def train(args):
+    train_util.verify_training_args(args)
+    train_util.prepare_dataset_args(args, True)
+    sdxl_train_util.verify_sdxl_training_args(args)
+
+    assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
+    assert (
+        not args.train_text_encoder or not args.cache_text_encoder_outputs
+    ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
+
+    cache_latents = args.cache_latents
+
+    if args.seed is not None:
+        set_seed(args.seed)  # initialize the random number sequence
+
+    tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
+
+    # prepare the dataset
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            user_config = {
+                "datasets": [
+                    {
+                        "subsets": [
+                            {
+                                "image_dir": args.train_data_dir,
+                                "metadata_file": args.in_json,
+                            }
+                        ]
+                    }
+                ]
+            }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
+
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
+    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
+
+    if args.debug_dataset:
+        train_util.debug_dataset(train_dataset_group, True)
+        return
+    if len(train_dataset_group) == 0:
+        print(
+            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
+        )
+        return
+
+    if cache_latents:
+        assert (
+            train_dataset_group.is_latent_cacheable()
+        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
+    if args.cache_text_encoder_outputs:
+        assert (
+            train_dataset_group.is_text_encoder_output_cacheable()
+        ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
+
+    # prepare the accelerator
+    print("prepare accelerator")
+    accelerator = train_util.prepare_accelerator(args)
+
+    # prepare dtypes matching the mixed precision setting, and cast as needed
+    weight_dtype, save_dtype = train_util.prepare_dtype(args)
+
+    # load the models
+    (
+        load_stable_diffusion_format,
+        text_encoder1,
+        text_encoder2,
+        vae,
+        unet,
+        text_projection,
+        logit_scale,
+        ckpt_info,
+    ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
+    text_projection = text_projection.to(accelerator.device, dtype=weight_dtype)
+    logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
+
+    # verify load/save model formats
+    if load_stable_diffusion_format:
+        src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
+        src_diffusers_model_path = None
+    else:
+        src_stable_diffusion_ckpt = None
+        src_diffusers_model_path = args.pretrained_model_name_or_path
+
+    if args.save_model_as is None:
+        save_stable_diffusion_format = load_stable_diffusion_format
+        use_safetensors = args.use_safetensors
+    else:
+        save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
+        use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
+        assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
|
||||||
|
|
||||||
|
# Diffusers版のxformers使用フラグを設定する関数
|
||||||
|
def set_diffusers_xformers_flag(model, valid):
|
||||||
|
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||||
|
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||||
|
module.set_use_memory_efficient_attention_xformers(valid)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_set_mem_eff(child)
|
||||||
|
|
||||||
|
fn_recursive_set_mem_eff(model)
|
||||||
|
|
||||||
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
|
if args.diffusers_xformers:
|
||||||
|
# もうU-Netを独自にしたので動かないけどVAEのxformersは動くはず
|
||||||
|
accelerator.print("Use xformers by Diffusers")
|
||||||
|
# set_diffusers_xformers_flag(unet, True)
|
||||||
|
set_diffusers_xformers_flag(vae, True)
|
||||||
|
else:
|
||||||
|
# Windows版のxformersはfloatで学習できなかったりxformersを使わない設定も可能にしておく必要がある
|
||||||
|
accelerator.print("Disable Diffusers' xformers")
|
||||||
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||||
|
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
||||||
|
|
||||||
|
# 学習を準備する
|
||||||
|
if cache_latents:
|
||||||
|
# vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
vae.to(accelerator.device, dtype=torch.float32) # VAE in float to avoid NaN
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
|
vae.to("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
# 学習を準備する:モデルを適切な状態にする
|
||||||
|
training_models = []
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
unet.enable_gradient_checkpointing()
|
||||||
|
training_models.append(unet)
|
||||||
|
|
||||||
|
if args.train_text_encoder:
|
||||||
|
# TODO each option for two text encoders?
|
||||||
|
accelerator.print("enable text encoder training")
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
text_encoder1.gradient_checkpointing_enable()
|
||||||
|
text_encoder2.gradient_checkpointing_enable()
|
||||||
|
training_models.append(text_encoder1)
|
||||||
|
training_models.append(text_encoder2)
|
||||||
|
else:
|
||||||
|
text_encoder1.requires_grad_(False)
|
||||||
|
text_encoder2.requires_grad_(False)
|
||||||
|
text_encoder1.eval()
|
||||||
|
text_encoder2.eval()
|
||||||
|
|
||||||
|
if not cache_latents:
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
for m in training_models:
|
||||||
|
m.requires_grad_(True)
|
||||||
|
params = []
|
||||||
|
for m in training_models:
|
||||||
|
params.extend(m.parameters())
|
||||||
|
params_to_optimize = params
|
||||||
|
|
||||||
|
# calculate number of trainable parameters
|
||||||
|
n_params = 0
|
||||||
|
for p in params:
|
||||||
|
n_params += p.numel()
|
||||||
|
accelerator.print(f"number of models: {len(training_models)}")
|
||||||
|
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||||
|
|
||||||
|
# 学習に必要なクラスを準備する
|
||||||
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||||
|
|
||||||
|
# dataloaderを準備する
|
||||||
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
|
train_dataset_group,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=collater,
|
||||||
|
num_workers=n_workers,
|
||||||
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 学習ステップ数を計算する
|
||||||
|
if args.max_train_epochs is not None:
|
||||||
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
# データセット側にも学習ステップを送信
|
||||||
|
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||||
|
|
||||||
|
# lr schedulerを用意する
|
||||||
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
|
if args.full_fp16:
|
||||||
|
assert (
|
||||||
|
args.mixed_precision == "fp16"
|
||||||
|
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||||
|
accelerator.print("enable full fp16 training.")
|
||||||
|
unet.to(weight_dtype)
|
||||||
|
text_encoder1.to(weight_dtype)
|
||||||
|
text_encoder2.to(weight_dtype)
|
||||||
|
|
||||||
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
|
if args.train_text_encoder:
|
||||||
|
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
# transform DDP after prepare
|
||||||
|
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
|
||||||
|
else:
|
||||||
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
(unet,) = train_util.transform_models_if_DDP([unet])
|
||||||
|
text_encoder1.to(weight_dtype)
|
||||||
|
text_encoder2.to(weight_dtype)
|
||||||
|
text_encoder1.eval()
|
||||||
|
text_encoder2.eval()
|
||||||
|
|
||||||
|
# TextEncoderの出力をキャッシュする
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
||||||
|
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
|
||||||
|
)
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
text_encoder1.to("cpu")
|
||||||
|
text_encoder2.to("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
else:
|
||||||
|
text_encoder1_cache = None
|
||||||
|
text_encoder2_cache = None
|
||||||
|
text_encoder1.to(accelerator.device)
|
||||||
|
text_encoder2.to(accelerator.device)
|
||||||
|
|
||||||
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
|
if args.full_fp16:
|
||||||
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
|
# resumeする
|
||||||
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
|
# epoch数を計算する
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||||
|
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||||
|
|
||||||
|
# 学習する
|
||||||
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
accelerator.print("running training / 学習開始")
|
||||||
|
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||||
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
|
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
|
accelerator.print(
|
||||||
|
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||||
|
)
|
||||||
|
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
|
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
|
global_step = 0
|
||||||
|
|
||||||
|
noise_scheduler = DDPMScheduler(
|
||||||
|
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||||
|
)
|
||||||
|
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||||
|
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
|
||||||
|
|
||||||
|
for epoch in range(num_train_epochs):
|
||||||
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
|
for m in training_models:
|
||||||
|
m.train()
|
||||||
|
|
||||||
|
loss_total = 0
|
||||||
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
current_step.value = global_step
|
||||||
|
# with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
|
if True:
|
||||||
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
# latentに変換
|
||||||
|
# latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||||
|
latents = vae.encode(batch["images"].to(torch.float32)).latent_dist.sample().to(weight_dtype)
|
||||||
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
|
input_ids1 = batch["input_ids"]
|
||||||
|
input_ids2 = batch["input_ids2"]
|
||||||
|
if not args.cache_text_encoder_outputs:
|
||||||
|
with torch.set_grad_enabled(args.train_text_encoder):
|
||||||
|
# Get the text embedding for conditioning
|
||||||
|
# TODO support weighted captions
|
||||||
|
# if args.weighted_captions:
|
||||||
|
# encoder_hidden_states = get_weighted_text_embeddings(
|
||||||
|
# tokenizer,
|
||||||
|
# text_encoder,
|
||||||
|
# batch["captions"],
|
||||||
|
# accelerator.device,
|
||||||
|
# args.max_token_length // 75 if args.max_token_length else 1,
|
||||||
|
# clip_skip=args.clip_skip,
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
input_ids1 = input_ids1.to(accelerator.device)
|
||||||
|
input_ids2 = input_ids2.to(accelerator.device)
|
||||||
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
|
||||||
|
args,
|
||||||
|
input_ids1,
|
||||||
|
input_ids2,
|
||||||
|
tokenizer1,
|
||||||
|
tokenizer2,
|
||||||
|
text_encoder1,
|
||||||
|
text_encoder2,
|
||||||
|
None if not args.full_fp16 else weight_dtype,
|
||||||
|
)
|
||||||
|
pool2 = pool2 @ text_projection.to(pool2.dtype)
|
||||||
|
else:
|
||||||
|
encoder_hidden_states1 = []
|
||||||
|
encoder_hidden_states2 = []
|
||||||
|
pool2 = []
|
||||||
|
for input_id1, input_id2 in zip(input_ids1, input_ids2):
|
||||||
|
input_id1 = input_id1.squeeze(0)
|
||||||
|
input_id2 = input_id2.squeeze(0)
|
||||||
|
encoder_hidden_states1.append(text_encoder1_cache[tuple(input_id1.tolist())])
|
||||||
|
hidden_states2, p2 = text_encoder2_cache[tuple(input_id2.tolist())]
|
||||||
|
encoder_hidden_states2.append(hidden_states2)
|
||||||
|
pool2.append(p2)
|
||||||
|
encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
|
||||||
|
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
||||||
|
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
||||||
|
|
||||||
|
pool2 = pool2 @ text_projection.to(pool2.dtype)
|
||||||
|
|
||||||
|
# get size embeddings
|
||||||
|
orig_size = batch["original_sizes_hw"]
|
||||||
|
crop_size = batch["crop_top_lefts"]
|
||||||
|
target_size = batch["target_sizes_hw"]
|
||||||
|
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||||
|
|
||||||
|
# concat embeddings
|
||||||
|
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||||
|
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||||
|
|
||||||
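                # Shape notes for reference (assuming SDXL base, batch size B, sequence length L):
                #   encoder_hidden_states1: [B, L, 768]   (CLIP ViT-L hidden states)
                #   encoder_hidden_states2: [B, L, 1280]  (OpenCLIP ViT-bigG hidden states)
                #   text_embedding:  [B, L, 2048] from the concat on dim=2, the U-Net's cross-attention context
                #   pool2: [B, 1280] and size embeddings: [B, 1536], so vector_embedding: [B, 2816],
                #   the added conditioning fed to the U-Net's label embedding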
                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                if args.noise_offset:
                    noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
                elif args.multires_noise_iterations:
                    noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

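                # For reference, add_noise realizes the standard DDPM forward process:
                #   noisy_latents = sqrt(alphas_cumprod[t]) * latents + sqrt(1 - alphas_cumprod[t]) * noise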
                # Predict the noise residual
                with accelerator.autocast():
                    noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)

                target = noise

                if args.min_snr_gamma:
                    # do not mean over batch dimension for snr weight or scale v-pred loss
                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                    loss = loss.mean([1, 2, 3])

                    if args.min_snr_gamma:
                        loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

                    loss = loss.mean()  # mean over batch dimension
                else:
                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

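                # Note: apply_snr_weight implements the Min-SNR-gamma strategy; for epsilon prediction the
                # per-sample loss is scaled by roughly min(SNR(t), gamma) / SNR(t), down-weighting the
                # low-noise (easy) timesteps so training focuses on harder ones.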
                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = []
                    for m in training_models:
                        params_to_clip.extend(m.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # sdxl_train_util.sample_images(
                #     accelerator,
                #     args,
                #     None,
                #     global_step,
                #     accelerator.device,
                #     vae,
                #     tokenizer1,
                #     tokenizer2,
                #     text_encoder1,
                #     text_encoder2,
                #     unet,
                # )

                # save the model at every specified number of steps
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
                        sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            src_path,
                            save_stable_diffusion_format,
                            use_safetensors,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            accelerator.unwrap_model(text_encoder1),
                            accelerator.unwrap_model(text_encoder2),
                            accelerator.unwrap_model(unet),
                            vae,
                            text_projection,
                            logit_scale,
                            ckpt_info,
                        )

            current_loss = loss.detach().item()  # this is the mean, so batch size should not matter
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if (
                    args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
                ):  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)

            # TODO use a moving average
            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
                sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
                    args,
                    True,
                    accelerator,
                    src_path,
                    save_stable_diffusion_format,
                    use_safetensors,
                    save_dtype,
                    epoch,
                    num_train_epochs,
                    global_step,
                    accelerator.unwrap_model(text_encoder1),
                    accelerator.unwrap_model(text_encoder2),
                    accelerator.unwrap_model(unet),
                    vae,
                    text_projection,
                    logit_scale,
                    ckpt_info,
                )

        # train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

    is_main_process = accelerator.is_main_process
    # if is_main_process:
    unet = accelerator.unwrap_model(unet)
    text_encoder1 = accelerator.unwrap_model(text_encoder1)
    text_encoder2 = accelerator.unwrap_model(text_encoder2)

    accelerator.end_training()

    if args.save_state:  # and is_main_process:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # memory is needed after this point, so free it

    if is_main_process:
        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
        sdxl_train_util.save_sd_model_on_train_end(
            args,
            src_path,
            save_stable_diffusion_format,
            use_safetensors,
            save_dtype,
            epoch,
            global_step,
            text_encoder1,
            text_encoder2,
            unet,
            vae,
            text_projection,
            logit_scale,
            ckpt_info,
        )
        print("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    custom_train_functions.add_custom_train_arguments(parser)
    sdxl_train_util.add_sdxl_training_arguments(parser)

    parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
    parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)
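For reference, a hypothetical single-GPU invocation assembled from the arguments referenced above; the model and data file names are placeholders, and exact flag spellings come from the library parsers, so verify against --help before use:

    accelerate launch sdxl_train.py \
        --pretrained_model_name_or_path sd_xl_base_0.9.safetensors \
        --in_json meta_lat.json \
        --train_data_dir train_images \
        --mixed_precision fp16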
172	sdxl_train_network.py	Normal file
@@ -0,0 +1,172 @@
import argparse
import torch
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network


class SdxlNetworkTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()
        self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR

    def assert_extra_args(self, args, train_dataset_group):
        super().assert_extra_args(args)
        sdxl_train_util.verify_sdxl_training_args(args)

        if args.cache_text_encoder_outputs:
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

        assert (
            args.network_train_unet_only or not args.cache_text_encoder_outputs
        ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"

    def load_target_model(self, args, weight_dtype, accelerator):
        (
            load_stable_diffusion_format,
            text_encoder1,
            text_encoder2,
            vae,
            unet,
            text_projection,
            logit_scale,
            ckpt_info,
        ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)

        self.load_stable_diffusion_format = load_stable_diffusion_format
        self.text_projection = text_projection.to(accelerator.device, dtype=weight_dtype)
        self.logit_scale = logit_scale
        self.ckpt_info = ckpt_info

        return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet

    def load_tokenizer(self, args):
        tokenizer = sdxl_train_util.load_tokenizers(args)
        return tokenizer

    def is_text_encoder_outputs_cached(self, args):
        return args.cache_text_encoder_outputs

    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
    ):
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # reduce memory consumption
                print("move vae and unet to cpu to save memory")
                org_vae_device = vae.device
                org_unet_device = unet.device
                vae.to("cpu")
                unet.to("cpu")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
                args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype
            )
            accelerator.wait_for_everyone()
            text_encoders[0].to("cpu")
            text_encoders[1].to("cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self.text_encoder1_cache = text_encoder1_cache
            self.text_encoder2_cache = text_encoder2_cache

            if not args.lowram:
                print("move vae and unet back to original device")
                vae.to(org_vae_device)
                unet.to(org_unet_device)
        else:
            self.text_encoder1_cache = None
            self.text_encoder2_cache = None

            # outputs are fetched from the Text Encoders every step, so keep them on the GPU
            text_encoders[0].to(accelerator.device)
            text_encoders[1].to(accelerator.device)

    def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
        input_ids1 = batch["input_ids"]
        input_ids2 = batch["input_ids2"]
        if not args.cache_text_encoder_outputs:
            with torch.enable_grad():
                # Get the text embedding for conditioning
                # TODO support weighted captions
                # if args.weighted_captions:
                #     encoder_hidden_states = get_weighted_text_embeddings(
                #         tokenizer,
                #         text_encoder,
                #         batch["captions"],
                #         accelerator.device,
                #         args.max_token_length // 75 if args.max_token_length else 1,
                #         clip_skip=args.clip_skip,
                #     )
                # else:
                input_ids1 = input_ids1.to(accelerator.device)
                input_ids2 = input_ids2.to(accelerator.device)
                encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
                    args,
                    input_ids1,
                    input_ids2,
                    tokenizers[0],
                    tokenizers[1],
                    text_encoders[0],
                    text_encoders[1],
                    None if not args.full_fp16 else weight_dtype,
                )
                pool2 = pool2 @ self.text_projection.to(pool2.dtype)
        else:
            encoder_hidden_states1 = []
            encoder_hidden_states2 = []
            pool2 = []
            for input_id1, input_id2 in zip(input_ids1, input_ids2):
                input_id1 = input_id1.squeeze(0)
                input_id2 = input_id2.squeeze(0)
                encoder_hidden_states1.append(self.text_encoder1_cache[tuple(input_id1.tolist())])
                hidden_states2, p2 = self.text_encoder2_cache[tuple(input_id2.tolist())]
                encoder_hidden_states2.append(hidden_states2)
                pool2.append(p2)
            encoder_hidden_states1 = torch.stack(encoder_hidden_states1).to(accelerator.device).to(weight_dtype)
            encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
            pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)

            pool2 = pool2 @ self.text_projection.to(weight_dtype)

        return encoder_hidden_states1, encoder_hidden_states2, pool2

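    # Note on the caches above (a sketch of the assumed structure): sdxl_train_util.cache_text_encoder_outputs
    # appears to return dicts keyed by the tuple of token ids, e.g.
    #   hidden_states2, pooled2 = text_encoder2_cache[tuple(input_id2.tolist())]
    # so identical captions are encoded once and looked up afterwards, letting the Text Encoders stay on CPU.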
    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
        noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

        # get size embeddings
        orig_size = batch["original_sizes_hw"]
        crop_size = batch["crop_top_lefts"]
        target_size = batch["target_sizes_hw"]
        embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)

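        # Size embeddings are SDXL's micro-conditioning: original size (h, w), crop top-left (top, left)
        # and target size (h, w) are each embedded sinusoidally (presumably 256 dims per scalar, giving
        # 6 * 256 = 1536 dims in total) and concatenated with the pooled text embedding below.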
        # concat embeddings
        encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
        vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
        text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)

        noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
        return noise_pred

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
        print("sample_images is not implemented")


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
    sdxl_train_util.add_sdxl_training_arguments(parser)
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    trainer = SdxlNetworkTrainer()
    trainer.train(args)
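Similarly, a hypothetical SDXL LoRA run; the network flags come from train_network.py's parser and are assumptions here, and the file names are placeholders:

    accelerate launch sdxl_train_network.py \
        --pretrained_model_name_or_path sd_xl_base_0.9.safetensors \
        --dataset_config dataset.toml \
        --network_module networks.lora \
        --network_dim 8 \
        --cache_text_encoder_outputs --network_train_unet_only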
1525	train_network.py
File diff suppressed because it is too large