mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add sdxl fine-tuning and LoRA
This commit is contained in:
@@ -1061,6 +1061,16 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
||||
|
||||
return text_model, vae, unet
|
||||
|
||||
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
||||
# only for reference
|
||||
version_str = "sd"
|
||||
if v2:
|
||||
version_str += "_v2"
|
||||
else:
|
||||
version_str += "_v1"
|
||||
if v_parameterization:
|
||||
version_str += "_v"
|
||||
return version_str
|
||||
|
||||
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
||||
def convert_key(key):
|
||||
|
||||
@@ -6,6 +6,10 @@ from library import model_util
|
||||
from library import sdxl_original_unet
|
||||
|
||||
|
||||
VAE_SCALE_FACTOR = 0.13025
|
||||
MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"
|
||||
|
||||
|
||||
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
|
||||
|
||||
@@ -76,8 +80,8 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
return new_sd, text_projection, logit_scale
|
||||
|
||||
|
||||
def load_models_from_sdxl_checkpoint(model_type, ckpt_path, map_location):
|
||||
# model_type is reserved to future use
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
||||
# model_version is reserved for future use
|
||||
|
||||
# Load the state dict
|
||||
if model_util.is_safetensors(ckpt_path):
|
||||
|
||||
@@ -1069,8 +1069,8 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
t_emb = t_emb.to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
assert y.shape[0] == x.shape[0]
|
||||
assert x.dtype == y.dtype
|
||||
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
||||
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
||||
# assert x.dtype == self.dtype
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
@@ -1105,6 +1105,8 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
print("create unet")
|
||||
unet = SdxlUNet2DConditionModel()
|
||||
|
||||
@@ -1132,8 +1134,11 @@ if __name__ == "__main__":
|
||||
print("start training")
|
||||
steps = 10
|
||||
batch_size = 1
|
||||
|
||||
for step in range(steps):
|
||||
print(f"step {step}")
|
||||
if step == 1:
|
||||
time_start = time.perf_counter()
|
||||
|
||||
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
|
||||
t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
|
||||
@@ -1149,3 +1154,6 @@ if __name__ == "__main__":
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
time_end = time.perf_counter()
|
||||
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
|
||||
|
||||
384
library/sdxl_train_util.py
Normal file
384
library/sdxl_train_util.py
Normal file
@@ -0,0 +1,384 @@
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
import open_clip
|
||||
from library import model_util, sdxl_model_util, train_util
|
||||
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
|
||||
DEFAULT_NOISE_OFFSET = 0.0357
|
||||
|
||||
|
||||
# TODO: separate checkpoint for each U-Net/Text Encoder/VAE
|
||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder1.to(accelerator.device)
|
||||
text_encoder2.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
|
||||
|
||||
|
||||
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
|
||||
# only supports StableDiffusion
|
||||
name_or_path = args.pretrained_model_name_or_path
|
||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
assert (
|
||||
load_stable_diffusion_format
|
||||
), f"only supports StableDiffusion format for SDXL / SDXLではStableDiffusion形式のみサポートしています: {name_or_path}"
|
||||
|
||||
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||
(
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
|
||||
|
||||
# VAEを読み込む
|
||||
if args.vae is not None:
|
||||
vae = model_util.load_vae(args.vae, weight_dtype)
|
||||
print("additional VAE loaded")
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, text_projection, logit_scale, ckpt_info
|
||||
|
||||
|
||||
class WrapperTokenizer:
|
||||
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
|
||||
def __init__(self):
|
||||
open_clip_tokenizer = open_clip.tokenizer._tokenizer
|
||||
self.model_max_length = 77
|
||||
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
|
||||
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
|
||||
self.pad_token_id = 0 # 結果から推定している
|
||||
|
||||
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
||||
return self.tokenize(*args, **kwds)
|
||||
|
||||
def tokenize(self, text, padding, truncation, max_length, return_tensors):
|
||||
assert padding == "max_length"
|
||||
assert truncation == True
|
||||
assert return_tensors == "pt"
|
||||
input_ids = open_clip.tokenize(text, context_length=max_length)
|
||||
return SimpleNamespace(**{"input_ids": input_ids})
|
||||
|
||||
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
print("prepare tokenizers")
|
||||
original_path = TOKENIZER_PATH
|
||||
|
||||
tokenizer1: CLIPTokenizer = None
|
||||
if args.tokenizer_cache_dir:
|
||||
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
||||
if os.path.exists(local_tokenizer_path):
|
||||
print(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||
tokenizer1 = CLIPTokenizer.from_pretrained(local_tokenizer_path)
|
||||
|
||||
if tokenizer1 is None:
|
||||
tokenizer1 = CLIPTokenizer.from_pretrained(original_path)
|
||||
|
||||
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||
print(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||
tokenizer1.save_pretrained(local_tokenizer_path)
|
||||
|
||||
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
||||
print(f"update token length: {args.max_token_length}")
|
||||
|
||||
# tokenizer2 is from open_clip
|
||||
# TODO caching
|
||||
tokenizer2 = WrapperTokenizer()
|
||||
|
||||
return [tokenizer1, tokenizer2]
|
||||
|
||||
|
||||
def get_hidden_states(
|
||||
args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
|
||||
):
|
||||
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
|
||||
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
|
||||
|
||||
# text_encoder1
|
||||
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
|
||||
hidden_states1 = enc_out["hidden_states"][11]
|
||||
|
||||
# text_encoder2
|
||||
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
||||
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
||||
pool2 = enc_out["pooler_output"]
|
||||
|
||||
if args.max_token_length is not None:
|
||||
# bs*3, 77, 768 or 1024
|
||||
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||
states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, args.max_token_length, tokenizer1.model_max_length):
|
||||
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
|
||||
hidden_states1 = torch.cat(states_list, dim=1)
|
||||
|
||||
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
||||
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, args.max_token_length, tokenizer2.model_max_length):
|
||||
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
|
||||
if i > 0:
|
||||
for j in range(len(chunk)):
|
||||
if input_ids2[j, 1] == tokenizer2.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
||||
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
||||
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
||||
hidden_states2 = torch.cat(states_list, dim=1)
|
||||
|
||||
if weight_dtype is not None:
|
||||
# this is required for additional network training
|
||||
hidden_states1 = hidden_states1.to(weight_dtype)
|
||||
hidden_states2 = hidden_states2.to(weight_dtype)
|
||||
|
||||
return hidden_states1, hidden_states2, pool2
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
device=timesteps.device
|
||||
)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
def get_timestep_embedding(x, outdim):
|
||||
assert len(x.shape) == 2
|
||||
b, dims = x.shape[0], x.shape[1]
|
||||
x = torch.flatten(x)
|
||||
emb = timestep_embedding(x, outdim)
|
||||
emb = torch.reshape(emb, (b, dims * outdim))
|
||||
return emb
|
||||
|
||||
|
||||
def get_size_embeddings(orig_size, crop_size, target_size, device):
|
||||
emb1 = get_timestep_embedding(orig_size, 256)
|
||||
emb2 = get_timestep_embedding(crop_size, 256)
|
||||
emb3 = get_timestep_embedding(target_size, 256)
|
||||
vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
|
||||
return vector
|
||||
|
||||
|
||||
def save_sd_model_on_train_end(
|
||||
args: argparse.Namespace,
|
||||
src_path: str,
|
||||
save_stable_diffusion_format: bool,
|
||||
use_safetensors: bool,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
vae,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||
ckpt_file,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
epoch_no,
|
||||
global_step,
|
||||
ckpt_info,
|
||||
vae,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
save_dtype,
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
raise NotImplementedError("diffusers_saver is not implemented")
|
||||
|
||||
train_util.save_sd_model_on_train_end_common(
|
||||
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
|
||||
)
|
||||
|
||||
|
||||
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
||||
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
||||
def save_sd_model_on_epoch_end_or_stepwise(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format: bool,
|
||||
use_safetensors: bool,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
vae,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||
ckpt_file,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
unet,
|
||||
epoch_no,
|
||||
global_step,
|
||||
ckpt_info,
|
||||
vae,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
save_dtype,
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
raise NotImplementedError("diffusers_saver is not implemented")
|
||||
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args,
|
||||
on_epoch_end,
|
||||
accelerator,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
sd_saver,
|
||||
diffusers_saver,
|
||||
)
|
||||
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
|
||||
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype):
|
||||
print("caching text encoder outputs")
|
||||
|
||||
tokenizer1, tokenizer2 = tokenizers
|
||||
text_encoder1, text_encoder2 = text_encoders
|
||||
text_encoder1.to(accelerator.device)
|
||||
text_encoder2.to(accelerator.device)
|
||||
if weight_dtype is not None:
|
||||
text_encoder1.to(dtype=weight_dtype)
|
||||
text_encoder2.to(dtype=weight_dtype)
|
||||
|
||||
text_encoder1_cache = {}
|
||||
text_encoder2_cache = {}
|
||||
for batch in tqdm(data_loader):
|
||||
input_ids1_batch = batch["input_ids"]
|
||||
input_ids2_batch = batch["input_ids2"]
|
||||
|
||||
# split batch to avoid OOM
|
||||
# TODO specify batch size by args
|
||||
for input_ids1, input_ids2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
|
||||
# remove input_ids already in cache
|
||||
input_ids1 = input_ids1.squeeze(0)
|
||||
input_ids2 = input_ids2.squeeze(0)
|
||||
input_ids1 = [i for i in input_ids1 if i not in text_encoder1_cache]
|
||||
input_ids2 = [i for i in input_ids2 if i not in text_encoder2_cache]
|
||||
assert len(input_ids1) == len(input_ids2)
|
||||
if len(input_ids1) == 0:
|
||||
continue
|
||||
input_ids1 = torch.stack(input_ids1).to(accelerator.device)
|
||||
input_ids2 = torch.stack(input_ids2).to(accelerator.device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
|
||||
args,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizer1,
|
||||
tokenizer2,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
)
|
||||
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu")
|
||||
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu")
|
||||
pool2 = pool2.to("cpu")
|
||||
for input_id1, input_id2, hidden_states1, hidden_states2, p2 in zip(
|
||||
input_ids1, input_ids2, encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
):
|
||||
text_encoder1_cache[tuple(input_id1.tolist())] = hidden_states1
|
||||
text_encoder2_cache[tuple(input_id2.tolist())] = (hidden_states2, p2)
|
||||
return text_encoder1_cache, text_encoder2_cache
|
||||
|
||||
|
||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||
)
|
||||
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace):
|
||||
assert (
|
||||
not args.v2 and not args.v_parameterization
|
||||
), "v2 or v_parameterization cannot be enabled in SDXL training / SDXL学習ではv2とv_parameterizationを有効にすることはできません"
|
||||
if args.clip_skip is not None:
|
||||
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
||||
|
||||
if args.multires_noise_iterations:
|
||||
print(
|
||||
f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
|
||||
)
|
||||
else:
|
||||
if args.noise_offset is None:
|
||||
args.noise_offset = DEFAULT_NOISE_OFFSET
|
||||
elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
||||
print(
|
||||
f"Waring: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
||||
)
|
||||
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||
|
||||
assert (
|
||||
not args.weighted_captions
|
||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||
@@ -798,6 +798,19 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def is_latent_cacheable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
def is_text_encoder_output_cacheable(self):
|
||||
return all(
|
||||
[
|
||||
not (
|
||||
subset.caption_dropout_rate > 0
|
||||
or subset.shuffle_caption
|
||||
or subset.token_warmup_step > 0
|
||||
or subset.caption_tag_dropout_rate > 0
|
||||
)
|
||||
for subset in self.subsets
|
||||
]
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
|
||||
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
|
||||
@@ -850,7 +863,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
|
||||
cache_available = self.is_disk_cached_latents_is_expected(
|
||||
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if self.flip_aug else None
|
||||
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if subset.flip_aug else None
|
||||
)
|
||||
|
||||
if cache_available: # do not add to batch
|
||||
@@ -1719,6 +1732,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||
|
||||
def is_text_encoder_output_cacheable(self) -> bool:
|
||||
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_current_epoch(epoch)
|
||||
@@ -3284,11 +3300,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
# TODO remove this function in the future
|
||||
def transform_if_model_is_DDP(text_encoder, unet, network=None):
|
||||
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
|
||||
|
||||
|
||||
def transform_models_if_DDP(models):
|
||||
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||
return [model.module if type(model) == DDP else model for model in models if model is not None]
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
@@ -3430,6 +3452,42 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
text_encoder,
|
||||
unet,
|
||||
vae,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
|
||||
save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args,
|
||||
on_epoch_end,
|
||||
accelerator,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
sd_saver,
|
||||
diffusers_saver,
|
||||
)
|
||||
|
||||
|
||||
def save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator,
|
||||
save_stable_diffusion_format: bool,
|
||||
use_safetensors: bool,
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
sd_saver,
|
||||
diffusers_saver,
|
||||
):
|
||||
if on_epoch_end:
|
||||
epoch_no = epoch + 1
|
||||
@@ -3457,9 +3515,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
sd_saver(ckpt_file, epoch_no, global_step)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||
@@ -3483,9 +3539,8 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step))
|
||||
|
||||
print(f"\nsaving model: {out_dir}")
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
diffusers_saver(out_dir)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name)
|
||||
|
||||
@@ -3578,6 +3633,30 @@ def save_sd_model_on_train_end(
|
||||
text_encoder,
|
||||
unet,
|
||||
vae,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
|
||||
save_sd_model_on_train_end_common(
|
||||
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
|
||||
)
|
||||
|
||||
|
||||
def save_sd_model_on_train_end_common(
|
||||
args: argparse.Namespace,
|
||||
save_stable_diffusion_format: bool,
|
||||
use_safetensors: bool,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
sd_saver,
|
||||
diffusers_saver,
|
||||
):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
|
||||
|
||||
@@ -3588,9 +3667,8 @@ def save_sd_model_on_train_end(
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
|
||||
)
|
||||
sd_saver(ckpt_file, epoch, global_step)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||
else:
|
||||
@@ -3598,9 +3676,8 @@ def save_sd_model_on_train_end(
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
print(f"save trained model as Diffusers to {out_dir}")
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
diffusers_saver(out_dir)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user