From 61a61c51ee59184f24440a629afb2dcb760f0136 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 1 Jan 2023 21:46:38 +0900 Subject: [PATCH 01/26] Add --save_last_n_epochs option --- train_db.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/train_db.py b/train_db.py index 1dde882c..5a1fb721 100644 --- a/train_db.py +++ b/train_db.py @@ -27,6 +27,7 @@ import itertools import math import os import random +import shutil from tqdm import tqdm import torch @@ -1101,16 +1102,28 @@ def train(args): ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) + if args.save_last_n_epochs is not None: + old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_last_n_epochs)) + if os.path.exists(old_ckpt_file): + os.remove(old_ckpt_file) else: out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) os.makedirs(out_dir, exist_ok=True) model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), src_diffusers_model_path, use_safetensors=use_safetensors) + if args.save_last_n_epochs is not None: + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_last_n_epochs)) + if os.path.exists(out_dir_old): + shutil.rmtree(out_dir_old) if args.save_state: print("saving state.") accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + if args.save_last_n_epochs is not None: + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1 - args.save_last_n_epochs)) + if os.path.exists(state_dir_old): + shutil.rmtree(state_dir_old) is_main_process = accelerator.is_main_process if is_main_process: @@ -1173,6 +1186,8 @@ if __name__ == '__main__': help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_last_n_epochs", type=int, default=None, + help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) 
/ optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") From 85d8b4912955228ab57f194b0eba951a338849b0 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 1 Jan 2023 23:36:20 +0900 Subject: [PATCH 02/26] Fix calculation for the old epoch --- train_db.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_db.py b/train_db.py index 5a1fb721..a0e2357f 100644 --- a/train_db.py +++ b/train_db.py @@ -1103,7 +1103,7 @@ def train(args): model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) if args.save_last_n_epochs is not None: - old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_last_n_epochs)) + old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) if os.path.exists(old_ckpt_file): os.remove(old_ckpt_file) else: @@ -1113,7 +1113,7 @@ def train(args): unwrap_model(unet), src_diffusers_model_path, use_safetensors=use_safetensors) if args.save_last_n_epochs is not None: - out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_last_n_epochs)) + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) if os.path.exists(out_dir_old): shutil.rmtree(out_dir_old) @@ -1121,7 +1121,7 @@ def train(args): print("saving state.") accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) if args.save_last_n_epochs is not None: - state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1 - args.save_last_n_epochs)) + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) if os.path.exists(state_dir_old): shutil.rmtree(state_dir_old) From 6b522b34c1529ac222ac6ee469a6023492627b55 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 2 Jan 2023 16:08:21 +0900 Subject: [PATCH 03/26] move code for xformers to train_util --- fine_tune.py | 328 +-------------- library/train_util.py | 910 ++++++++++++++++++++++++++++++++++++++++++ train_db.py | 307 +------------- train_network.py | 896 +---------------------------------------- 4 files changed, 940 insertions(+), 1501 deletions(-) create mode 100644 library/train_util.py diff --git a/fine_tune.py b/fine_tune.py index 49d84dcc..53ace2e8 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -1,35 +1,4 @@ -# v2: select precision for saved checkpoint -# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset) -# v4: support SD2.0, add lr scheduler options, supports save_every_n_epochs and save_state for DiffUsers model -# v5: refactor to use model_util, support safetensors, add settings to use Diffusers' xformers, add log prefix -# v6: model_util update -# v7: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0, support full path in metadata -# v8: experimental full fp16 training. -# v9: add keep_tokens and save_model_as option, flip augmentation - -# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします -# License: -# Copyright 2022 Kohya S. 
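# Illustrative aside, not part of the patches above: a minimal sketch of the
# checkpoint-pruning rule that PATCH 01 introduces and PATCH 02 corrects.
# Checkpoints are only written every `save_every_n_epochs` epochs, so after
# saving at (0-based) `epoch`, the artifact old enough to drop is the one for
# epoch `epoch + 1 - save_every_n_epochs * save_last_n_epochs`. The helper and
# its callback arguments below are hypothetical; only the arithmetic and the
# os.remove / shutil.rmtree calls mirror the diff.
import os
import shutil


def remove_old_epoch_artifacts(output_dir, epoch, save_every_n_epochs, save_last_n_epochs,
                               ckpt_name_for_epoch, diffusers_dir_for_epoch, state_dir_for_epoch):
    old_epoch_no = epoch + 1 - save_every_n_epochs * save_last_n_epochs
    if old_epoch_no <= 0:
        return  # nothing old enough to prune yet

    old_ckpt = os.path.join(output_dir, ckpt_name_for_epoch(old_epoch_no))
    if os.path.exists(old_ckpt):
        os.remove(old_ckpt)  # single-file .ckpt / .safetensors checkpoint

    # the Diffusers output and the accelerate training state are directories
    for dir_name in (diffusers_dir_for_epoch(old_epoch_no), state_dir_for_epoch(old_epoch_no)):
        old_dir = os.path.join(output_dir, dir_name)
        if os.path.exists(old_dir):
            shutil.rmtree(old_dir)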
@kohya_ss -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# License of included scripts: - -# Diffusers: ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE - -# Memory efficient attention: -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE +# training with captions import argparse import math @@ -51,17 +20,7 @@ from einops import rearrange from torch import einsum import library.model_util as model_util - -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - -# checkpointファイル名 -EPOCH_STATE_NAME = "epoch-{:06d}-state" -LAST_STATE_NAME = "last-state" - -LAST_DIFFUSERS_DIR_NAME = "last" -EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" +import library.train_util as train_util def collate_fn(examples): @@ -288,9 +247,9 @@ def train(args): # tokenizerを読み込む print("prepare tokenizer") if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained(train_util.V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") else: - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + tokenizer = CLIPTokenizer.from_pretrained(train_util.TOKENIZER_PATH) if args.max_token_length is not None: print(f"update token length: {args.max_token_length}") @@ -403,7 +362,7 @@ def train(args): # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある print("Disable Diffusers' xformers") set_diffusers_xformers_flag(unet, False) - replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) if not fine_tuning: # Hypernetwork @@ -667,7 +626,7 @@ def train(args): model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) else: - out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) + out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) os.makedirs(out_dir, exist_ok=True) model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) @@ -676,7 +635,7 @@ def train(args): if args.save_state: print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) is_main_process = accelerator.is_main_process if is_main_process: @@ -690,7 +649,7 @@ def train(args): if args.save_state: print("saving last state.") - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + 
accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME)) del accelerator # この後メモリを使うのでこれは消す @@ -706,7 +665,7 @@ def train(args): else: # Create the pipeline using using the trained modules and save it. print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME) + out_dir = os.path.join(args.output_dir, train_util.LAST_DIFFUSERS_DIR_NAME) os.makedirs(out_dir, exist_ok=True) model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) @@ -717,275 +676,6 @@ def train(args): print("model saved.") -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.function.Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = (q.shape[-1] ** -0.5) - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum('... i j, ... j d -> ... 
i d', exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) - - p = exp_attn_weights / lc - - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... 
j d', ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() - - -def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, 'b h n d -> b n (h d)') - - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_flash_attn - - -def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) - - context = default(context, x) - context = context.to(x.dtype) - - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - out = rearrange(out, 'b n h d -> b n (h d)', h=h) - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_xformers -# endregion - - if __name__ == '__main__': # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() diff --git a/library/train_util.py b/library/train_util.py new file mode 100644 index 00000000..f525a431 --- /dev/null +++ b/library/train_util.py @@ -0,0 +1,910 @@ +# common functions for training + +import json +from typing import NamedTuple +from torch.autograd.function import Function +import glob +import math +import os +import random + +from tqdm import tqdm +import torch +from torchvision import transforms +from transformers import CLIPTokenizer +import diffusers +import albumentations as albu +import numpy as np +from PIL import Image +import cv2 +from einops import rearrange +from torch import einsum + +import library.model_util as model_util + +# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う +TOKENIZER_PATH = 
"openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + +# checkpointファイル名 +EPOCH_STATE_NAME = "epoch-{:06d}-state" +LAST_STATE_NAME = "last-state" + +EPOCH_FILE_NAME = "epoch-{:06d}" +LAST_FILE_NAME = "last" + + +# region dataset + +class ImageInfo(): + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: tuple[int, int] = None + self.bucket_reso: tuple[int, int] = None + self.latents: torch.Tensor = None + self.latents_flipped: torch.Tensor = None + self.latents_npz: str = None + self.latents_npz_flipped: str = None + + +class BucketBatchIndex(NamedTuple): + bucket_index: int + batch_index: int + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None: + super().__init__() + self.tokenizer: CLIPTokenizer = tokenizer + self.max_token_length = max_token_length + self.shuffle_caption = shuffle_caption + self.shuffle_keep_tokens = shuffle_keep_tokens + self.width, self.height = resolution + self.face_crop_aug_range = face_crop_aug_range + self.flip_aug = flip_aug + self.color_aug = color_aug + self.debug_dataset = debug_dataset + + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + + # augmentation + flip_p = 0.5 if flip_aug else 0.0 + if color_aug: + # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る + self.aug = albu.Compose([ + albu.OneOf([ + albu.HueSaturationValue(8, 0, 0, p=.5), + albu.RandomGamma((95, 105), p=.5), + ], p=.33), + albu.HorizontalFlip(p=flip_p) + ], p=1.) + elif flip_aug: + self.aug = albu.Compose([ + albu.HorizontalFlip(p=flip_p) + ], p=1.) + else: + self.aug = None + + self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) + + self.image_data: dict[str, ImageInfo] = {} + + def process_caption(self, caption): + if self.shuffle_caption: + tokens = caption.strip().split(",") + if self.shuffle_keep_tokens is None: + random.shuffle(tokens) + else: + if len(tokens) > self.shuffle_keep_tokens: + keep_tokens = tokens[:self.shuffle_keep_tokens] + tokens = tokens[self.shuffle_keep_tokens:] + random.shuffle(tokens) + tokens = keep_tokens + tokens + caption = ",".join(tokens).strip() + return caption + + def get_input_ids(self, caption): + input_ids = self.tokenizer(caption, padding="max_length", truncation=True, + max_length=self.tokenizer_max_length, return_tensors="pt").input_ids + + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... 
" でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = (input_ids[0].unsqueeze(0), + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): + ids_chunk = (input_ids[0].unsqueeze(0), # BOS + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo): + self.image_data[info.image_key] = info + + def make_buckets(self, enable_bucket, min_size, max_size): + ''' + bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) + min_size and max_size are ignored when enable_bucket is False + ''' + + self.enable_bucket = enable_bucket + + print("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + # bucketingを用意する + if enable_bucket: + bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size) + else: + # bucketはひとつだけ、すべての画像は同じ解像度 + bucket_resos = [(self.width, self.height)] + bucket_aspect_ratios = [self.width / self.height] + bucket_aspect_ratios = np.array(bucket_aspect_ratios) + + # bucketを作成する + if enable_bucket: + img_ar_errors = [] + for image_info in self.image_data.values(): + # bucketを決める + image_width, image_height = image_info.image_size + aspect_ratio = image_width / image_height + ar_errors = bucket_aspect_ratios - aspect_ratio + + bucket_id = np.abs(ar_errors).argmin() + image_info.bucket_reso = bucket_resos[bucket_id] + + ar_error = ar_errors[bucket_id] + img_ar_errors.append(ar_error) + else: + reso = (self.width, self.height) + for image_info in self.image_data.values(): + image_info.bucket_reso = reso + + # 画像をbucketに分割する + self.buckets: list[str] = [[] for _ in range(len(bucket_resos))] + reso_to_index = {} + for i, reso in enumerate(bucket_resos): + reso_to_index[reso] = i + + for image_info in self.image_data.values(): + bucket_index = reso_to_index[image_info.bucket_reso] + for _ in range(image_info.num_repeats): + self.buckets[bucket_index].append(image_info.image_key) + + if enable_bucket: + print("number of images (including repeats for DreamBooth) / 各bucketの画像枚数(DreamBoothの場合は繰り返し回数を含む)") + for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): + print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") + img_ar_errors = np.array(img_ar_errors) + print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") + + # 参照用indexを作る + self.buckets_indices: list(BucketBatchIndex) = [] + for bucket_index, bucket in 
enumerate(self.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, batch_index)) + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + random.shuffle(self.buckets_indices) + for bucket in self.buckets: + random.shuffle(bucket) + + def load_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + def resize_and_trim(self, image, reso): + image_height, image_width = image.shape[0:2] + ar_img = image_width / image_height + ar_reso = reso[0] / reso[1] + if ar_img > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) + + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if resized_size[0] > reso[0]: + trim_size = resized_size[0] - reso[0] + image = image[:, trim_size//2:trim_size//2 + reso[0]] + elif resized_size[1] > reso[1]: + trim_size = resized_size[1] - reso[1] + image = image[trim_size//2:trim_size//2 + reso[1]] + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ + f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def cache_latents(self, vae): + print("caching latents.") + for info in tqdm(self.image_data.values()): + if info.latents_npz is not None: + info.latents = self.load_latents_from_npz(info, False) + info.latents = torch.FloatTensor(info.latents) + info.latents_flipped = self.load_latents_from_npz(info, True) + info.latents_flipped = torch.FloatTensor(info.latents_flipped) + continue + + image = self.load_image(info.absolute_path) + image = self.resize_and_trim(image, info.bucket_reso) + + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + if self.flip_aug: + image = image[:, ::-1].copy() # cannot convert to Tensor without copy + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + def get_image_size(self, image_path): + image = Image.open(image_path) + return image.size + + def load_image_with_face_info(self, image_path: str): + img = self.load_image(image_path) + + face_cx = face_cy = face_w = face_h = 0 + if self.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サイズはsizeより大きいのでリサイズする + face_size = max(face_w, face_h) + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) + min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ + if min_scale >= max_scale: # 
range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + .5) + nw = int(width * scale + .5) + assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + .5) + face_cy = int(face_cy * scale + .5) + height, width = nh, nw + + # 顔を中心として448*640とかへ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 + + if self.random_crop: + # 背景も含めるために顔を中心に置く確率を高めつつずらす + range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 + else: + # range指定があるときのみ、すこしだけランダムに(わりと適当) + if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: + if face_size > self.size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1:p1 + target_size, :] + else: + image = image[:, p1:p1 + target_size] + + return image + + def load_latents_from_npz(self, image_info: ImageInfo, flipped): + npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + return np.load(npz_file)['arr_0'] + + def __len__(self): + return self._length + + def __getitem__(self, index): + if index == 0: + self.shuffle_buckets() + + bucket = self.buckets[self.buckets_indices[index].bucket_index] + image_index = self.buckets_indices[index].batch_index * self.batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + + for image_key in bucket[image_index:image_index + self.batch_size]: + image_info = self.image_data[image_key] + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + # image/latentsを処理する + if image_info.latents is not None: + latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: + latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み込み、必要ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.resize_and_trim(img, image_info.bucket_reso) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p:p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p:p + self.width] + + im_h, im_w = img.shape[0:2] + assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + + # augmentation + if self.aug is not None: + img = self.aug(image=img)['image'] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = 
self.process_caption(image_info.caption) + captions.append(caption) + input_ids_list.append(self.get_input_ids(caption)) + + example = {} + example['loss_weights'] = torch.FloatTensor(loss_weights) + example['input_ids'] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example['images'] = images + + example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None + + if self.debug_dataset: + example['image_keys'] = bucket[image_index:image_index + self.batch_size] + example['captions'] = captions + return example + + +class DreamBoothDataset(BaseDataset): + def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, + resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.prior_loss_weight = prior_loss_weight + self.random_crop = random_crop + self.latents_cache = None + self.enable_bucket = False + + def read_caption(img_path): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding='utf-8') as f: + lines = f.readlines() + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + caption = lines[0].strip() + break + return caption + + def load_dreambooth_dir(dir): + if not os.path.isdir(dir): + # print(f"ignore file: {dir}") + return 0, [], [] + + tokens = os.path.basename(dir).split('_') + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") + return 0, [], [] + + caption_by_folder = '_'.join(tokens[1:]) + img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ + glob.glob(os.path.join(dir, "*.webp")) + print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files") + + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path) + captions.append(caption_by_folder if cap_for_img is None else cap_for_img) + + return n_repeats, img_paths, captions + + print("prepare train images.") + train_dirs = os.listdir(train_data_dir) + num_train_images = 0 + for dir in train_dirs: + n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) + num_train_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, n_repeats, caption, False, img_path) + self.register_image(info) + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + # reg imageは数を数えて学習画像と同じ枚数にする + num_reg_images = 0 + if reg_data_dir: + print("prepare reg images.") + reg_infos: list[ImageInfo] = [] + + reg_dirs = os.listdir(reg_data_dir) + for dir in reg_dirs: + n_repeats, img_paths, captions = 
load_dreambooth_dir(os.path.join(reg_data_dir, dir)) + num_reg_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, n_repeats, caption, True, img_path) + reg_infos.append(info) + + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + + if num_reg_images == 0: + print("no regularization images / 正則化画像が見つかりませんでした") + else: + n = 0 + while n < num_train_images: + for info in reg_infos: + self.register_image(info) + n += info.num_repeats + if n >= num_train_images: # reg画像にnum_repeats>1のときはまずありえないので考慮しない + break + + self.num_reg_images = num_reg_images + + +class FineTuningDataset(BaseDataset): + def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, + resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + + # メタデータを読み込む + if os.path.exists(json_file_name): + print(f"loading existing metadata: {json_file_name}") + with open(json_file_name, "rt", encoding='utf-8') as f: + metadata = json.load(f) + else: + raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}") + + self.metadata = metadata + self.train_data_dir = train_data_dir + self.batch_size = batch_size + + for image_key, img_md in metadata.items(): + # path情報を作る + if os.path.exists(image_key): + abs_path = image_key + else: + # わりといい加減だがいい方法が思いつかん + abs_path = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) + + glob.glob(os.path.join(train_data_dir, f"{image_key}.webp"))) + assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}" + abs_path = abs_path[0] + + caption = img_md.get('caption') + tags = img_md.get('tags') + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ', ' + tags + assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" + + image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) + image_info.image_size = img_md.get('train_resolution') + + if not self.color_aug: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key) + + self.register_image(image_info) + self.num_train_images = len(metadata) * dataset_repeats + self.num_reg_images = 0 + + # check existence of all npz files + if not self.color_aug: + npz_any = False + npz_all = True + for image_info in self.image_data.values(): + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz + + if self.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + npz_all = npz_all and has_npz + + if npz_any and not npz_all: + break + + if not npz_any: + print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") + elif not npz_all: + print(f"some of npz file does not exist. 
ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None + + # check min/max bucket size + sizes = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + + if sizes is None: + self.min_bucket_reso = self.max_bucket_reso = None # set as not calculated + else: + self.min_bucket_reso = min(sizes) + self.max_bucket_reso = max(sizes) + + def image_key_to_npz_file(self, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + '.npz' + + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + '_flip.npz' + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip + + # image_key is relative path + npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') + npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') + + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None + + return npz_file_norm, npz_file_flip + +# endregion + + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.function.Function): + @ staticmethod + @ torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """ Algorithm 2 in the paper """ + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = (q.shape[-1] ** -0.5) + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, 'b n -> b 1 1 n') + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum('... i d, ... j d -> ... 
i j', qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @ staticmethod + @ torch.no_grad() + def backward(ctx, do): + """ Algorithm 4 in the paper """ + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2) + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.) + + p = exp_attn_weights / lc + + dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) + dp = einsum('... i d, ... j d -> ... i j', doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) + dk_chunk = einsum('... i j, ... i d -> ... 
j d', ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() + + +def replace_unet_cross_attn_to_memory_efficient(): + print("Replace CrossAttention.forward to use FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + + context = context if context is not None else x + context = context.to(x.dtype) + + if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, 'b h n d -> b n (h d)') + + # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_flash_attn + + +def replace_unet_cross_attn_to_xformers(): + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + + context = default(context, x) + context = context.to(x.dtype) + + if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + + out = rearrange(out, 'b n h d -> b n (h d)', h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers +# endregion diff --git a/train_db.py b/train_db.py index 1dde882c..35ae1212 100644 --- a/train_db.py +++ b/train_db.py @@ -1,22 +1,4 @@ -# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - -# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images, -# enable reg images in fine-tuning, add dataset_repeats option -# v8: supports Diffusers 0.7.2 -# v9: add bucketing option -# v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth -# v11: Diffusers 0.9.0 is required. 
support for Stable Diffusion 2.0/v-parameterization -# add lr scheduler options, change handling folder/file caption, support loading DiffUser model from Huggingface -# support save_ever_n_epochs/save_state in DiffUsers model -# fix the issue that prior_loss_weight is applied to train images -# v12: stop train text encode, tqdm smoothing -# v13: bug fix -# v14: refactor to use model_util, add log prefix, support safetensors, support vae loading, keep vae in CPU to save the loaded vae -# v15: model_util update -# v16: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0 -# v17: add fp16 gradient training (experimental) -# v18: add save_model_as option +# DreamBooth training import gc import time @@ -44,19 +26,8 @@ from einops import rearrange from torch import einsum import library.model_util as model_util - -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - -# CLIP_ID_L14_336 = "openai/clip-vit-large-patch14-336" - -# checkpointファイル名 -EPOCH_STATE_NAME = "epoch-{:06d}-state" -LAST_STATE_NAME = "last-state" - -EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" -LAST_DIFFUSERS_DIR_NAME = "last" +import library.train_util as train_util +from library.train_util import DreamBoothDataset, FineTuningDataset # region dataset @@ -392,266 +363,10 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): # endregion -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = (q.shape[-1] ** -0.5) - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... 
i j', qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) - - p = exp_attn_weights / lc - - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... 
j d', ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() - - -def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use FlashAttention") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - k = self.to_k(context) - v = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, 'b h n d -> b n (h d)') - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_flash_attn - - -def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) - - context = default(context, x) - context = context.to(x.dtype) - - k_in = self.to_k(context) - v_in = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format - # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format - del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - out = rearrange(out, 'b n h d -> b n (h d)', h=h) - # out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_xformers -# endregion - - def collate_fn(examples): return examples[0] -# def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") -# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) -# return text_encoder - - def train(args): if args.caption_extention is not None: args.caption_extension = args.caption_extention @@ -679,7 +394,7 @@ def train(args): else: src_stable_diffusion_ckpt = None src_diffusers_model_path = args.pretrained_model_name_or_path - + if args.save_model_as is None: save_stable_diffusion_format = load_stable_diffusion_format use_safetensors = args.use_safetensors @@ -790,9 +505,9 @@ def train(args): # tokenizerを読み込む print("prepare tokenizer") if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained(train_util.V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") else: - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + tokenizer = CLIPTokenizer.from_pretrained(train_util.TOKENIZER_PATH) print("prepare dataset") train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, @@ -884,7 +599,7 @@ def train(args): print("additional VAE loaded") # モデルに xformers 
とか memory efficient attention を組み込む - replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) # 学習を準備する if cache_latents: @@ -1102,7 +817,7 @@ def train(args): model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) else: - out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) + out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) os.makedirs(out_dir, exist_ok=True) model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), src_diffusers_model_path, @@ -1110,7 +825,7 @@ def train(args): if args.save_state: print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) is_main_process = accelerator.is_main_process if is_main_process: @@ -1121,7 +836,7 @@ def train(args): if args.save_state: print("saving last state.") - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME)) del accelerator # この後メモリを使うのでこれは消す @@ -1134,7 +849,7 @@ def train(args): src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae) else: print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME) + out_dir = os.path.join(args.output_dir, train_util.LAST_DIFFUSERS_DIR_NAME) os.makedirs(out_dir, exist_ok=True) model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_diffusers_model_path, use_safetensors=use_safetensors) diff --git a/train_network.py b/train_network.py index f26ced8b..bcb9db8e 100644 --- a/train_network.py +++ b/train_network.py @@ -2,890 +2,23 @@ import gc import importlib import json import time -from typing import NamedTuple -from torch.autograd.function import Function import argparse -import glob import math import os -import random from tqdm import tqdm import torch -from torchvision import transforms from accelerate import Accelerator from accelerate.utils import set_seed from transformers import CLIPTokenizer import diffusers from diffusers import DDPMScheduler, StableDiffusionPipeline -import albumentations as albu import numpy as np -from PIL import Image import cv2 -from einops import rearrange -from torch import einsum import library.model_util as model_util - -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - -# checkpointファイル名 -EPOCH_STATE_NAME = "epoch-{:06d}-state" -LAST_STATE_NAME = "last-state" - -EPOCH_FILE_NAME = "epoch-{:06d}" -LAST_FILE_NAME = "last" - - -# region dataset - -class ImageInfo(): - def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: - self.image_key: str = image_key - self.num_repeats: int = num_repeats - self.caption: str = caption - self.is_reg: bool = is_reg - self.absolute_path: str = absolute_path - self.image_size: tuple[int, int] = None - self.bucket_reso: tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - 
self.latents_npz_flipped: str = None - - -class BucketBatchIndex(NamedTuple): - bucket_index: int - batch_index: int - - -class BaseDataset(torch.utils.data.Dataset): - def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None: - super().__init__() - self.tokenizer: CLIPTokenizer = tokenizer - self.max_token_length = max_token_length - self.shuffle_caption = shuffle_caption - self.shuffle_keep_tokens = shuffle_keep_tokens - self.width, self.height = resolution - self.face_crop_aug_range = face_crop_aug_range - self.flip_aug = flip_aug - self.color_aug = color_aug - self.debug_dataset = debug_dataset - - self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 - - # augmentation - flip_p = 0.5 if flip_aug else 0.0 - if color_aug: - # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る - self.aug = albu.Compose([ - albu.OneOf([ - albu.HueSaturationValue(8, 0, 0, p=.5), - albu.RandomGamma((95, 105), p=.5), - ], p=.33), - albu.HorizontalFlip(p=flip_p) - ], p=1.) - elif flip_aug: - self.aug = albu.Compose([ - albu.HorizontalFlip(p=flip_p) - ], p=1.) - else: - self.aug = None - - self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) - - self.image_data: dict[str, ImageInfo] = {} - - def process_caption(self, caption): - if self.shuffle_caption: - tokens = caption.strip().split(",") - if self.shuffle_keep_tokens is None: - random.shuffle(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[:self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens:] - random.shuffle(tokens) - tokens = keep_tokens + tokens - caption = ",".join(tokens).strip() - return caption - - def get_input_ids(self, caption): - input_ids = self.tokenizer(caption, padding="max_length", truncation=True, - max_length=self.tokenizer_max_length, return_tensors="pt").input_ids - - if self.tokenizer_max_length > self.tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) - ids_chunk = (input_ids[0].unsqueeze(0), - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): - ids_chunk = (input_ids[0].unsqueeze(0), # BOS - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) - - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: - ids_chunk[-1] = self.tokenizer.eos_token_id - # 先頭が ... の場合は ... 
に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id - - iids_list.append(ids_chunk) - - input_ids = torch.stack(iids_list) # 3,77 - return input_ids - - def register_image(self, info: ImageInfo): - self.image_data[info.image_key] = info - - def make_buckets(self, enable_bucket, min_size, max_size): - ''' - bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - min_size and max_size are ignored when enable_bucket is False - ''' - - self.enable_bucket = enable_bucket - - print("loading image sizes.") - for info in tqdm(self.image_data.values()): - if info.image_size is None: - info.image_size = self.get_image_size(info.absolute_path) - - if enable_bucket: - print("make buckets") - else: - print("prepare dataset") - - # bucketingを用意する - if enable_bucket: - bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size) - else: - # bucketはひとつだけ、すべての画像は同じ解像度 - bucket_resos = [(self.width, self.height)] - bucket_aspect_ratios = [self.width / self.height] - bucket_aspect_ratios = np.array(bucket_aspect_ratios) - - # bucketを作成する - if enable_bucket: - img_ar_errors = [] - for image_info in self.image_data.values(): - # bucketを決める - image_width, image_height = image_info.image_size - aspect_ratio = image_width / image_height - ar_errors = bucket_aspect_ratios - aspect_ratio - - bucket_id = np.abs(ar_errors).argmin() - image_info.bucket_reso = bucket_resos[bucket_id] - - ar_error = ar_errors[bucket_id] - img_ar_errors.append(ar_error) - else: - reso = (self.width, self.height) - for image_info in self.image_data.values(): - image_info.bucket_reso = reso - - # 画像をbucketに分割する - self.buckets: list[str] = [[] for _ in range(len(bucket_resos))] - reso_to_index = {} - for i, reso in enumerate(bucket_resos): - reso_to_index[reso] = i - - for image_info in self.image_data.values(): - bucket_index = reso_to_index[image_info.bucket_reso] - for _ in range(image_info.num_repeats): - self.buckets[bucket_index].append(image_info.image_key) - - if enable_bucket: - print("number of images (including repeats for DreamBooth) / 各bucketの画像枚数(DreamBoothの場合は繰り返し回数を含む)") - for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): - print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") - img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") - - # 参照用indexを作る - self.buckets_indices: list(BucketBatchIndex) = [] - for bucket_index, bucket in enumerate(self.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append(BucketBatchIndex(bucket_index, batch_index)) - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - for bucket in self.buckets: - random.shuffle(bucket) - - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img - - def resize_and_trim(self, image, reso): - image_height, image_width = image.shape[0:2] - ar_img = image_width / image_height - ar_reso = reso[0] / reso[1] - if ar_img > ar_reso: # 横が長い→縦を合わせる - scale = reso[1] / image_height - else: - scale = reso[0] / image_width - resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) - - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - 
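    # e.g. (worked example of the code above, assuming a 1024x512 source and a (768, 512) bucket):
    # ar_img = 2.0 > ar_reso = 1.5, so scale = 512 / 512 = 1.0 and resized_size = (1024, 512);
    # the width overshoots the bucket by 256 px, and the trim below keeps the centered
    # columns 128..896, i.e. a 768x512 crop.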
if resized_size[0] > reso[0]: - trim_size = resized_size[0] - reso[0] - image = image[:, trim_size//2:trim_size//2 + reso[0]] - elif resized_size[1] > reso[1]: - trim_size = resized_size[1] - reso[1] - image = image[trim_size//2:trim_size//2 + reso[1]] - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ - f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image - - def cache_latents(self, vae): - print("caching latents.") - for info in tqdm(self.image_data.values()): - if info.latents_npz is not None: - info.latents = self.load_latents_from_npz(info, False) - info.latents = torch.FloatTensor(info.latents) - info.latents_flipped = self.load_latents_from_npz(info, True) - info.latents_flipped = torch.FloatTensor(info.latents_flipped) - continue - - image = self.load_image(info.absolute_path) - image = self.resize_and_trim(image, info.bucket_reso) - - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - - if self.flip_aug: - image = image[:, ::-1].copy() # cannot convert to Tensor without copy - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - - def get_image_size(self, image_path): - image = Image.open(image_path) - return image.size - - def load_image_with_face_info(self, image_path: str): - img = self.load_image(image_path) - - face_cx = face_cy = face_w = face_h = 0 - if self.face_crop_aug_range is not None: - tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') - if len(tokens) >= 5: - face_cx = int(tokens[-4]) - face_cy = int(tokens[-3]) - face_w = int(tokens[-2]) - face_h = int(tokens[-1]) - - return img, face_cx, face_cy, face_w, face_h - - # いい感じに切り出す - def crop_target(self, image, face_cx, face_cy, face_w, face_h): - height, width = image.shape[0:2] - if height == self.height and width == self.width: - return image - - # 画像サイズはsizeより大きいのでリサイズする - face_size = max(face_w, face_h) - min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ - if min_scale >= max_scale: # range指定がmin==max - scale = min_scale - else: - scale = random.uniform(min_scale, max_scale) - - nh = int(height * scale + .5) - nw = int(width * scale + .5) - assert nh >= self.height and nw >= self.width, f"internal error. 
small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) - face_cx = int(face_cx * scale + .5) - face_cy = int(face_cy * scale + .5) - height, width = nh, nw - - # 顔を中心として448*640とかへ切り出す - for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): - p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - - if self.random_crop: - # 背景も含めるために顔を中心に置く確率を高めつつずらす - range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう - p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 - else: - # range指定があるときのみ、すこしだけランダムに(わりと適当) - if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: - if face_size > self.size // 10 and face_size >= 40: - p1 = p1 + random.randint(-face_size // 20, +face_size // 20) - - p1 = max(0, min(p1, length - target_size)) - - if axis == 0: - image = image[p1:p1 + target_size, :] - else: - image = image[:, p1:p1 + target_size] - - return image - - def load_latents_from_npz(self, image_info: ImageInfo, flipped): - npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz - return np.load(npz_file)['arr_0'] - - def __len__(self): - return self._length - - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - - bucket = self.buckets[self.buckets_indices[index].bucket_index] - image_index = self.buckets_indices[index].batch_index * self.batch_size - - loss_weights = [] - captions = [] - input_ids_list = [] - latents_list = [] - images = [] - - for image_key in bucket[image_index:image_index + self.batch_size]: - image_info = self.image_data[image_key] - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) - - # image/latentsを処理する - if image_info.latents is not None: - latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped - image = None - elif image_info.latents_npz is not None: - latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5) - latents = torch.FloatTensor(latents) - image = None - else: - # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img = self.resize_and_trim(img, image_info.bucket_reso) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p:p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p:p + self.width] - - im_h, im_w = img.shape[0:2] - assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" - - # augmentation - if self.aug is not None: - img = self.aug(image=img)['image'] - - latents = None - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - - images.append(image) - latents_list.append(latents) - - caption = self.process_caption(image_info.caption) - captions.append(caption) - input_ids_list.append(self.get_input_ids(caption)) - - example = {} - example['loss_weights'] = torch.FloatTensor(loss_weights) - example['input_ids'] = torch.stack(input_ids_list) 
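    # shape note: get_input_ids returns (1, 77) by default, or n chunks of 77 tokens when
    # max_token_length is 150/225 (n = 2 or 3), so this stack is (batch_size, 1 or n, 77);
    # the training loop later flattens it with input_ids.reshape((-1, tokenizer.model_max_length))
    # before calling the text encoder.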
- - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example['images'] = images - - example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None - - if self.debug_dataset: - example['image_keys'] = bucket[image_index:image_index + self.batch_size] - example['captions'] = captions - return example - - -class DreamBoothDataset(BaseDataset): - def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, - resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) - - self.batch_size = batch_size - self.size = min(self.width, self.height) # 短いほう - self.prior_loss_weight = prior_loss_weight - self.random_crop = random_crop - self.latents_cache = None - self.enable_bucket = False - - def read_caption(img_path): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding='utf-8') as f: - lines = f.readlines() - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() - break - return caption - - def load_dreambooth_dir(dir): - if not os.path.isdir(dir): - # print(f"ignore file: {dir}") - return 0, [], [] - - tokens = os.path.basename(dir).split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") - return 0, [], [] - - caption_by_folder = '_'.join(tokens[1:]) - img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ - glob.glob(os.path.join(dir, "*.webp")) - print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files") - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path) - captions.append(caption_by_folder if cap_for_img is None else cap_for_img) - - return n_repeats, img_paths, captions - - print("prepare train images.") - train_dirs = os.listdir(train_data_dir) - num_train_images = 0 - for dir in train_dirs: - n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) - num_train_images += n_repeats * len(img_paths) - for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, n_repeats, caption, False, img_path) - self.register_image(info) - print(f"{num_train_images} train images with repeating.") - self.num_train_images = num_train_images - - # reg imageは数を数えて学習画像と同じ枚数にする - num_reg_images = 0 - if reg_data_dir: - print("prepare reg images.") - reg_infos: list[ImageInfo] = [] - - reg_dirs = os.listdir(reg_data_dir) - for dir in reg_dirs: - n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) - num_reg_images += n_repeats * len(img_paths) - for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, n_repeats, caption, True, img_path) - reg_infos.append(info) - - 
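      # the regularization images gathered above are re-registered below until their total
      # repeat count reaches num_train_images, so an epoch sees roughly as many reg samples
      # as training samples; __getitem__ weights their loss with prior_loss_weight instead of 1.0.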
print(f"{num_reg_images} reg images.") - if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") - else: - n = 0 - while n < num_train_images: - for info in reg_infos: - self.register_image(info) - n += info.num_repeats - if n >= num_train_images: # reg画像にnum_repeats>1のときはまずありえないので考慮しない - break - - self.num_reg_images = num_reg_images - - -class FineTuningDataset(BaseDataset): - def __init__(self, metadata, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, - resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) - - self.metadata = metadata - self.train_data_dir = train_data_dir - self.batch_size = batch_size - - for image_key, img_md in metadata.items(): - # path情報を作る - if os.path.exists(image_key): - abs_path = image_key - else: - # わりといい加減だがいい方法が思いつかん - abs_path = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) + - glob.glob(os.path.join(train_data_dir, f"{image_key}.webp"))) - assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}" - abs_path = abs_path[0] - - caption = img_md.get('caption') - tags = img_md.get('tags') - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" - - image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) - image_info.image_size = img_md.get('train_resolution') - - if not self.color_aug: - # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key) - - self.register_image(image_info) - self.num_train_images = len(metadata) * dataset_repeats - self.num_reg_images = 0 - - # check existence of all npz files - if not self.color_aug: - npz_any = False - npz_all = True - for image_info in self.image_data.values(): - has_npz = image_info.latents_npz is not None - npz_any = npz_any or has_npz - - if self.flip_aug: - has_npz = has_npz and image_info.latents_npz_flipped is not None - npz_all = npz_all and has_npz - - if npz_any and not npz_all: - break - - if not npz_any: - print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") - elif not npz_all: - print(f"some of npz file does not exist. 
ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") - for image_info in self.image_data.values(): - image_info.latents_npz = image_info.latents_npz_flipped = None - - # check min/max bucket size - sizes = set() - for image_info in self.image_data.values(): - if image_info.image_size is None: - sizes = None # not calculated - break - sizes.add(image_info.image_size[0]) - sizes.add(image_info.image_size[1]) - - if sizes is None: - self.min_bucket_reso = self.max_bucket_reso = None # set as not calculated - else: - self.min_bucket_reso = min(sizes) - self.max_bucket_reso = max(sizes) - - def image_key_to_npz_file(self, image_key): - base_name = os.path.splitext(image_key)[0] - npz_file_norm = base_name + '.npz' - - if os.path.exists(npz_file_norm): - # image_key is full path - npz_file_flip = base_name + '_flip.npz' - if not os.path.exists(npz_file_flip): - npz_file_flip = None - return npz_file_norm, npz_file_flip - - # image_key is relative path - npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') - npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') - - if not os.path.exists(npz_file_norm): - npz_file_norm = None - npz_file_flip = None - elif not os.path.exists(npz_file_flip): - npz_file_flip = None - - return npz_file_norm, npz_file_flip - -# endregion - - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = (q.shape[-1] ** -0.5) - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... 
i j', qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) - - p = exp_attn_weights / lc - - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... 
j d', ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() - - -def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use FlashAttention") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - k = self.to_k(context) - v = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, 'b h n d -> b n (h d)') - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_flash_attn - - -def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) - - context = default(context, x) - context = context.to(x.dtype) - - k_in = self.to_k(context) - v_in = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format - # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format - del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - out = rearrange(out, 'b n h d -> b n (h d)', h=h) - # out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_xformers -# endregion +import library.train_util as train_util +from library.train_util import DreamBoothDataset, FineTuningDataset def collate_fn(examples): @@ -917,9 +50,9 @@ def train(args): # tokenizerを読み込む print("prepare tokenizer") if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained(train_util. V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") else: - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + tokenizer = CLIPTokenizer.from_pretrained(train_util. 
TOKENIZER_PATH) if args.max_token_length is not None: print(f"update token length: {args.max_token_length}") @@ -948,19 +81,10 @@ def train(args): else: print("Train with captions.") - # メタデータを読み込む - if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") - return - if args.color_aug: print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます") - train_dataset = FineTuningDataset(metadata, args.train_batch_size, args.train_data_dir, + train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, resolution, args.flip_aug, args.color_aug, face_crop_aug_range, args.dataset_repeats, args.debug_dataset) @@ -1063,7 +187,7 @@ def train(args): print("additional VAE loaded") # モデルに xformers とか memory efficient attention を組み込む - replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) # 学習を準備する if cache_latents: @@ -1329,12 +453,12 @@ def train(args): if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: print("saving checkpoint.") os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as) + ckpt_file = os.path.join(args.output_dir, train_util.EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as) unwrap_model(network).save_weights(ckpt_file, save_dtype) if args.save_state: print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) is_main_process = accelerator.is_main_process if is_main_process: @@ -1345,13 +469,13 @@ def train(args): if args.save_state: print("saving last state.") os.makedirs(args.output_dir, exist_ok=True) - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME)) del accelerator # この後メモリを使うのでこれは消す if is_main_process: os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' + args.save_model_as) + ckpt_file = os.path.join(args.output_dir, train_util.LAST_FILE_NAME + '.' 
+ args.save_model_as) print(f"save trained model to {ckpt_file}") network.save_weights(ckpt_file, save_dtype) print("model saved.") From 4c350067312afce037bf0ac2ba8042de78fcde84 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 3 Jan 2023 20:22:25 +0900 Subject: [PATCH 04/26] split common function from train_network to util --- fine_tune.py | 335 ++++++++---------------------------- library/train_util.py | 386 +++++++++++++++++++++++++++++++++++++++--- train_db.py | 5 - train_network.py | 297 ++++---------------------------- 4 files changed, 460 insertions(+), 563 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 53ace2e8..5da37b68 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -1,6 +1,8 @@ # training with captions +# XXX dropped option: fine_tune import argparse +import gc import math import os import random @@ -200,21 +202,13 @@ class FineTuningDataset(torch.utils.data.Dataset): return example -def save_hypernetwork(output_file, hypernetwork): - state_dict = hypernetwork.get_state_dict() - torch.save(state_dict, output_file) - - def train(args): - fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - # その他のオプション設定を確認する - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + cache_latents = args.cache_latents - # モデル形式のオプション設定を確認する + # verify load/save model formats load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) if load_stable_diffusion_format: @@ -231,109 +225,33 @@ def train(args): save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # 乱数系列を初期化する if args.seed is not None: - set_seed(args.seed) + set_seed(args.seed) # 乱数系列を初期化する - # メタデータを読み込む - if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") - return + tokenizer = train_util.load_tokenizer(args) - # tokenizerを読み込む - print("prepare tokenizer") - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(train_util.V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(train_util.TOKENIZER_PATH) - - if args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") - - # datasetを用意する - print("prepare dataset") - train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size, + train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.dataset_repeats, args.debug_dataset) - - print(f"Total dataset length / データセットの長さ: {len(train_dataset)}") - print(f"Total images / 画像数: {train_dataset.images_count}") + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset) + train_dataset.make_buckets() + if args.debug_dataset: + train_util.debug_dataset(train_dataset) + return if len(train_dataset) == 0: print("No data found. 
Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") return - if args.debug_dataset: - train_dataset.show_buckets() - i = 0 - for example in train_dataset: - print(f"image: {example['image_keys']}") - print(f"captions: {example['captions']}") - print(f"latents: {example['latents'].shape}") - print(f"input_ids: {example['input_ids'].shape}") - print(example['input_ids']) - i += 1 - if i >= 8: - break - return - # acceleratorを準備する print("prepare accelerator") - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = "tensorboard" - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False - - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) + accelerator, unwrap_model = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 + weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - if load_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) - else: - print("load Diffusers pretrained models") - pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) - # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる - text_encoder = pipe.text_encoder - unet = pipe.unet - vae = pipe.vae - del pipe - vae.to("cpu") # 保存時にしか使わないので、メモリを開けるためCPUに移しておく + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) # Diffusers版のxformers使用フラグを設定する関数 def set_diffusers_xformers_flag(model, valid): @@ -364,46 +282,38 @@ def train(args): set_diffusers_xformers_flag(unet, False) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - if not fine_tuning: - # Hypernetwork - print("import hypernetwork module:", args.hypernetwork_module) - hyp_module = importlib.import_module(args.hypernetwork_module) - - hypernetwork = hyp_module.Hypernetwork() - - if args.hypernetwork_weights is not None: - print("load hypernetwork weights from:", args.hypernetwork_weights) - hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu') - success = hypernetwork.load_from_state_dict(hyp_sd) - assert success, "hypernetwork weights loading failed." 
- - print("apply hypernetwork") - hypernetwork.apply_to_diffusers(None, text_encoder, unet) + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() # 学習を準備する:モデルを適切な状態にする training_models = [] - if fine_tuning: - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - training_models.append(unet) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) - if args.train_text_encoder: - print("enable text encoder training") - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - training_models.append(text_encoder) - else: - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) # text encoderは学習しない - text_encoder.eval() + if args.train_text_encoder: + print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + training_models.append(text_encoder) else: - unet.to(accelerator.device) # , dtype=weight_dtype) # dtypeを指定すると学習できない - unet.requires_grad_(False) - unet.eval() text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) + text_encoder.requires_grad_(False) # text encoderは学習しない text_encoder.eval() - training_models.append(hypernetwork) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) for m in training_models: m.requires_grad_(True) @@ -439,29 +349,19 @@ def train(args): lr_scheduler = diffusers.optimization.get_scheduler( args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) - # acceleratorがなんかよろしくやってくれるらしい + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) - if fine_tuning: - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - unet.to(weight_dtype) - text_encoder.to(weight_dtype) - - if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) else: - if args.full_fp16: - unet.to(weight_dtype) - hypernetwork.to(weight_dtype) - - unet, hypernetwork, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, hypernetwork, optimizer, train_dataloader, lr_scheduler) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: @@ -472,8 +372,6 @@ def train(args): accelerator.scaler._unscale_grads_ = _unscale_grads_replacer - # TODO accelerateのconfigに指定した型とオプション指定の型とをチェックして異なれば警告を出す - # resumeする if args.resume is not None: print(f"resume training from state: 
{args.resume}") @@ -497,17 +395,12 @@ def train(args): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # v4で更新:clip_sample=Falseに - # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる - # 既存の1.4/1.5/2.0/2.1はすべてschedulerのconfigは(クラス名を除いて)同じ - # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀') noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False) if accelerator.is_main_process: - accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork") + accelerator.init_trackers("finetuning") - # 以下 train_dreambooth.py からほぼコピペ for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") for m in training_models: @@ -524,38 +417,7 @@ def train(args): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) - input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 - - if args.clip_skip is None: - encoder_hidden_states = text_encoder(input_ids)[0] - else: - enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - - # bs*3, 77, 768 or 1024 - encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - - if args.max_token_length is not None: - if args.v2: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) - else: - # v1: ... の三連を ... 
へ戻す - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # - encoder_hidden_states = torch.cat(states_list, dim=1) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) @@ -616,23 +478,23 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: + def save_func(file): + model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), + src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) + train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func) if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: print("saving checkpoint.") os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) - if fine_tuning: - if save_stable_diffusion_format: - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), - src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) - else: - out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), - src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) + if save_stable_diffusion_format: + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), + src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) else: - save_hypernetwork(ckpt_file, unwrap_model(hypernetwork)) - + out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), + src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) if args.save_state: print("saving state.") accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) @@ -677,73 +539,16 @@ def train(args): if __name__ == '__main__': - # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") - parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル") - parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする") - parser.add_argument("--keep_tokens", type=int, default=None, - help="keep heading N tokens when shuffling caption 
tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--dataset_repeats", type=int, default=None, help="num times to repeat dataset / 学習にデータセットを繰り返す回数") - parser.add_argument("--output_dir", type=str, default=None, - help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)") - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)") - parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], - help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True) + train_util.add_training_arguments(parser, False) + parser.add_argument("--use_safetensors", action='store_true', help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") - parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") - parser.add_argument("--hypernetwork_module", type=str, default=None, - help='train hypernetwork instead of fine tuning, module to use / fine tuningの代わりにHypernetworkの学習をする場合、そのモジュール') - parser.add_argument("--hypernetwork_weights", type=str, default=None, - help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)') - parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") - parser.add_argument("--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) 
/ optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, - help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], - help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") - parser.add_argument("--train_batch_size", type=int, default=1, - help="batch size for training / 学習時のバッチサイズ") - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--mem_eff_attn", action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") - parser.add_argument("--xformers", action="store_true", - help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument("--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1, - help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - parser.add_argument("--debug_dataset", action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - parser.add_argument("--logging_dir", type=str, default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + help='use xformers by diffusers / Diffusersでxformersを使用する') args = parser.parse_args() train(args) diff --git a/library/train_util.py b/library/train_util.py index f525a431..8eedf48c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1,7 +1,10 @@ # common functions for training +import argparse import json +import time from typing import NamedTuple +from accelerate import Accelerator from torch.autograd.function import Function import glob import math @@ -13,6 +16,7 @@ import torch from torchvision import transforms from 
transformers import CLIPTokenizer import diffusers +from diffusers import DDPMScheduler, StableDiffusionPipeline import albumentations as albu import numpy as np from PIL import Image @@ -33,6 +37,9 @@ LAST_STATE_NAME = "last-state" EPOCH_FILE_NAME = "epoch-{:06d}" LAST_FILE_NAME = "last" +LAST_DIFFUSERS_DIR_NAME = "last" +EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" + # region dataset @@ -63,7 +70,8 @@ class BaseDataset(torch.utils.data.Dataset): self.max_token_length = max_token_length self.shuffle_caption = shuffle_caption self.shuffle_keep_tokens = shuffle_keep_tokens - self.width, self.height = resolution + # width/height is used when enable_bucket==False + self.width, self.height = (None, None) if resolution is None else resolution self.face_crop_aug_range = face_crop_aug_range self.flip_aug = flip_aug self.color_aug = color_aug @@ -149,35 +157,26 @@ class BaseDataset(torch.utils.data.Dataset): def register_image(self, info: ImageInfo): self.image_data[info.image_key] = info - def make_buckets(self, enable_bucket, min_size, max_size): + def make_buckets(self): ''' bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) min_size and max_size are ignored when enable_bucket is False ''' - - self.enable_bucket = enable_bucket - print("loading image sizes.") for info in tqdm(self.image_data.values()): if info.image_size is None: info.image_size = self.get_image_size(info.absolute_path) - if enable_bucket: + if self.enable_bucket: print("make buckets") else: print("prepare dataset") - # bucketingを用意する - if enable_bucket: - bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size) - else: - # bucketはひとつだけ、すべての画像は同じ解像度 - bucket_resos = [(self.width, self.height)] - bucket_aspect_ratios = [self.width / self.height] - bucket_aspect_ratios = np.array(bucket_aspect_ratios) + bucket_resos = self.bucket_resos + bucket_aspect_ratios = np.array(self.bucket_aspect_ratios) # bucketを作成する - if enable_bucket: + if self.enable_bucket: img_ar_errors = [] for image_info in self.image_data.values(): # bucketを決める @@ -191,9 +190,8 @@ class BaseDataset(torch.utils.data.Dataset): ar_error = ar_errors[bucket_id] img_ar_errors.append(ar_error) else: - reso = (self.width, self.height) for image_info in self.image_data.values(): - image_info.bucket_reso = reso + image_info.bucket_reso = bucket_resos[0] # bucket_resos contains (width, height) only # 画像をbucketに分割する self.buckets: list[str] = [[] for _ in range(len(bucket_resos))] @@ -206,8 +204,8 @@ class BaseDataset(torch.utils.data.Dataset): for _ in range(image_info.num_repeats): self.buckets[bucket_index].append(image_info.image_key) - if enable_bucket: - print("number of images (including repeats for DreamBooth) / 各bucketの画像枚数(DreamBoothの場合は繰り返し回数を含む)") + if self.enable_bucket: + print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") img_ar_errors = np.array(img_ar_errors) @@ -432,16 +430,27 @@ class BaseDataset(torch.utils.data.Dataset): class DreamBoothDataset(BaseDataset): - def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: + def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, 
shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight self.random_crop = random_crop self.latents_cache = None - self.enable_bucket = False + + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( + (self.width, self.height), min_bucket_reso, max_bucket_reso) + else: + self.bucket_resos = [(self.width, self.height)] + self.bucket_aspect_ratios = [self.width / self.height] def read_caption(img_path): # captionの候補ファイル名を作る @@ -532,9 +541,9 @@ class DreamBoothDataset(BaseDataset): class FineTuningDataset(BaseDataset): - def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None: + def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None: super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, - resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + resolution, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, debug_dataset) # メタデータを読み込む if os.path.exists(json_file_name): @@ -602,18 +611,35 @@ class FineTuningDataset(BaseDataset): # check min/max bucket size sizes = set() + resos = set() for image_info in self.image_data.values(): if image_info.image_size is None: sizes = None # not calculated break sizes.add(image_info.image_size[0]) sizes.add(image_info.image_size[1]) + resos.add(image_info.image_size) if sizes is None: - self.min_bucket_reso = self.max_bucket_reso = None # set as not calculated + assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" + + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( + (self.width, self.height), min_bucket_reso, max_bucket_reso) + else: + self.bucket_resos = [(self.width, self.height)] + self.bucket_aspect_ratios = [self.width / self.height] else: - self.min_bucket_reso = min(sizes) - 
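For reference, the bucket assignment above boils down to a nearest-aspect-ratio lookup. A minimal sketch of that selection rule (the bucket list below is made up for illustration; the real buckets come from model_util.make_bucket_resolutions and the assignment happens in BaseDataset.make_buckets):

import numpy as np

# illustrative buckets only; the training code derives these from
# model_util.make_bucket_resolutions((width, height), min_bucket_reso, max_bucket_reso)
bucket_resos = [(512, 512), (576, 448), (448, 576), (640, 384), (384, 640)]
bucket_aspect_ratios = np.array([w / h for w, h in bucket_resos])

def pick_bucket(image_width, image_height):
    # pick the bucket whose aspect ratio is closest to the image's aspect ratio,
    # i.e. the argmin over ar_errors used when setting image_info.bucket_reso
    ar_errors = bucket_aspect_ratios - (image_width / image_height)
    bucket_id = int(np.abs(ar_errors).argmin())
    return bucket_resos[bucket_id], float(ar_errors[bucket_id])

print(pick_bucket(1024, 768))  # ((576, 448), ~-0.05): a 4:3 image lands in the 576x448 bucket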
self.max_bucket_reso = max(sizes) + if not enable_bucket: + print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + print("using bucket info in metadata / メタデータ内のbucket情報を使います") + self.enable_bucket = True + self.bucket_resos = list(resos) + self.bucket_resos.sort() + self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos] def image_key_to_npz_file(self, image_key): base_name = os.path.splitext(image_key)[0] @@ -638,6 +664,28 @@ class FineTuningDataset(BaseDataset): return npz_file_norm, npz_file_flip +def debug_dataset(train_dataset): + print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + print("Escape for exit. / Escキーで中断、終了します") + k = 0 + for example in train_dataset: + if example['latents'] is not None: + print("sample has latents from npz file") + for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}') + if example['images'] is not None: + im = example['images'][j] + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27: + break + if k == 27 or example['images'] is None: + break + # endregion @@ -908,3 +956,289 @@ def replace_unet_cross_attn_to_xformers(): diffusers.models.attention.CrossAttention.forward = forward_xformers # endregion + + +# region utils + +def add_sd_models_arguments(parser: argparse.ArgumentParser): + # for pretrained models + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') + parser.add_argument("--v_parameterization", action='store_true', + help='enable v-parameterization training / v-parameterization学習を有効にする') + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") + + +def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): + parser.add_argument("--output_dir", type=str, default=None, + help="directory to output trained model / 学習後のモデル出力先ディレクトリ") + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") + parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") + parser.add_argument("--save_every_n_epochs", type=int, default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_state", action="store_true", + help="save training state additionally (including optimizer states etc.) 
/ optimizerなど学習状態も含めたstateを追加で保存する") + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") + + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") + parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") + parser.add_argument("--use_8bit_adam", action="store_true", + help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") + parser.add_argument("--mem_eff_attn", action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") + parser.add_argument("--xformers", action="store_true", + help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument("--vae", type=str, default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") + + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument("--gradient_checkpointing", action="store_true", + help="enable gradient checkpointing / grandient checkpointingを有効にする") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") + parser.add_argument("--mixed_precision", type=str, default="no", + choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument("--clip_skip", type=int, default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") + parser.add_argument("--logging_dir", type=str, default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument("--lr_scheduler", type=str, default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + + if support_dreambooth: + # DreamBooth training + parser.add_argument("--prior_loss_weight", type=float, default=1.0, + help="loss weight for regularization images / 正則化画像のlossの重み") + + +def verify_training_args(args: argparse.Namespace): + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + + +def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool): + # dataset common + parser.add_argument("--train_data_dir", type=str, default=None, help="directory 
for train images / 学習画像データのディレクトリ") + parser.add_argument("--shuffle_caption", action="store_true", + help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") + parser.add_argument("--keep_tokens", type=int, default=None, + help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") + parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") + parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument("--face_crop_aug_range", type=str, default=None, + help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") + parser.add_argument("--random_crop", action="store_true", + help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") + parser.add_argument("--debug_dataset", action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") + parser.add_argument("--resolution", type=str, default=None, + help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") + parser.add_argument("--cache_latents", action="store_true", + help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") + parser.add_argument("--enable_bucket", action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + + if support_dreambooth: + # DreamBooth dataset + parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + + if support_caption: + # caption dataset + parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") + parser.add_argument("--dataset_repeats", type=int, default=1, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") + + +def prepare_dataset_args(args: argparse.Namespace, support_caption: bool): + if args.cache_latents: + assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません" + + # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" + if args.resolution is not None: + args.resolution = tuple([int(r) for r in args.resolution.split(',')]) + if len(args.resolution) == 1: + args.resolution = (args.resolution[0], args.resolution[0]) + assert len(args.resolution) == 2, \ + f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + + if args.face_crop_aug_range is not None: + args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) + assert len(args.face_crop_aug_range) == 2, \ + f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" + else: + args.face_crop_aug_range = None + + if support_caption: + if args.in_json is not None and 
args.color_aug: + print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます") + + +def load_tokenizer(args: argparse.Namespace): + print("prepare tokenizer") + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + if args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + return tokenizer + + +def prepare_accelerator(args: argparse.Namespace): + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, + log_with=log_with, logging_dir=logging_dir) + + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False + + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) + + return accelerator, unwrap_model + + +def prepare_dtype(args: argparse.Namespace): + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 + + return weight_dtype, save_dtype + + +def load_target_model(args: argparse.Namespace, weight_dtype): + load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) + else: + print("load Diffusers pretrained models") + pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") + + return text_encoder, vae, unet, load_stable_diffusion_format + + +def patch_accelerator_for_fp16_training(accelerator): + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + + +def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): + b_size = input_ids.size()[0] + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 + + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] + else: + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] + if weight_dtype is not None: + # this is required for additional network 
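These helpers are meant to be called in a fixed order from each training script. A condensed sketch of that wiring, mirroring the refactored train_network.py shown further down (everything except the helper calls is omitted):

import library.train_util as train_util

def setup(args):
    # normalize and validate CLI arguments first
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)

    # then prepare tokenizer, accelerator, dtypes and the target model
    tokenizer = train_util.load_tokenizer(args)
    accelerator, unwrap_model = train_util.prepare_accelerator(args)
    weight_dtype, save_dtype = train_util.prepare_dtype(args)
    text_encoder, vae, unet, load_sd_format = train_util.load_target_model(args, weight_dtype)

    if args.full_fp16:
        # experimental: allow grad unscaling with fp16 gradients (see patch_accelerator_for_fp16_training)
        train_util.patch_accelerator_for_fp16_training(accelerator)

    return tokenizer, accelerator, unwrap_model, weight_dtype, save_dtype, text_encoder, vae, unet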
training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + return encoder_hidden_states + + +def save_on_epoch_end(args: argparse.Namespace, accelerator, epoch: int, num_train_epochs: int, save_func): + if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: + print("saving checkpoint.") + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as) + save_func(ckpt_file) + + if args.save_state: + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + + +def save_last_state(args, accelerator): + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + + +def save_last_model(args, save_func): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' 
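The shape bookkeeping in get_hidden_states is easier to follow with concrete numbers. A worked example for the v1 path with max_token_length=225 (a sketch of the arithmetic only, not part of the patch):

model_max_length = 77            # CLIP window: BOS + 75 content tokens + EOS
max_token_length = 225
n_chunks = max_token_length // (model_max_length - 2)   # 225 / 75 = 3 windows per caption

# dataset side: each caption becomes input_ids of shape (3, 77)
# get_hidden_states: (batch, 3, 77) -> (batch*3, 77) -> text encoder
#   -> (batch, 3*77, dim) -> keep the first BOS and the last EOS, drop the per-chunk ones
reassembled_length = 1 + n_chunks * (model_max_length - 2) + 1   # 1 + 225 + 1 = 227
assert reassembled_length == max_token_length + 2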
+ args.save_model_as) + print(f"save trained model to {ckpt_file}") + save_func(ckpt_file) + print("model saved.") + +# endregion diff --git a/train_db.py b/train_db.py index 50ed5a64..4a30d2d5 100644 --- a/train_db.py +++ b/train_db.py @@ -832,11 +832,6 @@ def train(args): if os.path.exists(out_dir_old): shutil.rmtree(out_dir_old) - - - - - if args.save_state: print("saving state.") accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) diff --git a/train_network.py b/train_network.py index bcb9db8e..35f50567 100644 --- a/train_network.py +++ b/train_network.py @@ -8,7 +8,6 @@ import os from tqdm import tqdm import torch -from accelerate import Accelerator from accelerate.utils import set_seed from transformers import CLIPTokenizer import diffusers @@ -26,165 +25,48 @@ def collate_fn(examples): def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + cache_latents = args.cache_latents - - # latentsをキャッシュする場合のオプション設定を確認する - if cache_latents: - assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません" - - # その他のオプション設定を確認する - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - use_dreambooth_method = args.in_json is None - # モデル形式のオプション設定を確認する: - load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) - - # 乱数系列を初期化する if args.seed is not None: set_seed(args.seed) - # tokenizerを読み込む - print("prepare tokenizer") - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(train_util. V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(train_util. 
TOKENIZER_PATH) - - if args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") - - # 学習データを用意する - assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" - resolution = tuple([int(r) for r in args.resolution.split(',')]) - if len(resolution) == 1: - resolution = (resolution[0], resolution[0]) - assert len(resolution) == 2, \ - f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" - - if args.face_crop_aug_range is not None: - face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) - assert len( - face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" - else: - face_crop_aug_range = None + tokenizer = train_util.load_tokenizer(args) # データセットを準備する if use_dreambooth_method: print("Use DreamBooth method.") train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.debug_dataset) + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) else: print("Train with captions.") - - if args.color_aug: - print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます") - train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - resolution, args.flip_aug, args.color_aug, face_crop_aug_range, args.dataset_repeats, args.debug_dataset) - - if train_dataset.min_bucket_reso is not None and (args.enable_bucket or train_dataset.min_bucket_reso != train_dataset.max_bucket_reso): - print(f"using bucket info in metadata / メタデータ内のbucket情報を使います") - args.min_bucket_reso = train_dataset.min_bucket_reso - args.max_bucket_reso = train_dataset.max_bucket_reso - args.enable_bucket = True - print(f"min bucket reso: {args.min_bucket_reso}, max bucket reso: {args.max_bucket_reso}") - - if args.enable_bucket: - assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" - - train_dataset.make_buckets(args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso) + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset) + train_dataset.make_buckets() if args.debug_dataset: - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. 
/ Escキーで中断、終了します") - k = 0 - for example in train_dataset: - if example['latents'] is not None: - print("sample has latents from npz file") - for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}') - if example['images'] is not None: - im = example['images'][j] - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == 27 or example['images'] is None: - break + train_util.debug_dataset(train_dataset) return - if len(train_dataset) == 0: print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") return # acceleratorを準備する print("prepare accelerator") - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = "tensorboard" - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) - - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, - log_with=log_with, logging_dir=logging_dir) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False - - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) + accelerator, unwrap_model = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 + weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - if load_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) - else: - print("load Diffusers pretrained models") - pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - del pipe - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -295,12 +177,7 @@ def train(args): # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: - org_unscale_grads = accelerator.scaler._unscale_grads_ - - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする if args.resume is not None: @@ 
-353,39 +230,7 @@ def train(args): with torch.set_grad_enabled(train_text_encoder): # Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) - input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 - - if args.clip_skip is None: - encoder_hidden_states = text_encoder(input_ids)[0] - else: - enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # なぜかこれが必要 - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - - # bs*3, 77, 768 or 1024 - encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - - if args.max_token_length is not None: - if args.v2: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) - else: - # v1: ... の三連を ... へ戻す - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # - encoder_hidden_states = torch.cat(states_list, dim=1) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) @@ -403,7 +248,6 @@ def train(args): if args.v_parameterization: # v-parameterization training - # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise @@ -450,15 +294,9 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: - print("saving checkpoint.") - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, train_util.EPOCH_FILE_NAME.format(epoch + 1) + '.' 
+ args.save_model_as) - unwrap_model(network).save_weights(ckpt_file, save_dtype) - - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) + def save_func(file): + unwrap_model(network).save_weights(file, save_dtype) + train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func) is_main_process = accelerator.is_main_process if is_main_process: @@ -467,103 +305,28 @@ def train(args): accelerator.end_training() if args.save_state: - print("saving last state.") - os.makedirs(args.output_dir, exist_ok=True) - accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME)) + train_util.save_last_state(args, accelerator) del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, train_util.LAST_FILE_NAME + '.' + args.save_model_as) - print(f"save trained model to {ckpt_file}") - network.save_weights(ckpt_file, save_dtype) - print("model saved.") + def last_save_func(file): + network.save_weights(file, save_dtype) + train_util.save_last_model(args, last_save_func) if __name__ == '__main__': - # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") - parser.add_argument("--network_weights", type=str, default=None, - help="pretrained weights for network / 学習するネットワークの初期重み") - parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") - parser.add_argument("--keep_tokens", type=int, default=None, - help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") - parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") - parser.add_argument("--dataset_repeats", type=int, default=1, - help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") - parser.add_argument("--output_dir", type=str, default=None, - help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") - parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") - parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") - parser.add_argument("--save_state", action="store_true", - 
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") - parser.add_argument("--face_crop_aug_range", type=str, default=None, - help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") - parser.add_argument("--random_crop", action="store_true", - help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") - parser.add_argument("--debug_dataset", action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - parser.add_argument("--resolution", type=str, default=None, - help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") - parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") - parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], - help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--mem_eff_attn", action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") - parser.add_argument("--xformers", action="store_true", - help="use xformers for CrossAttention / CrossAttentionにxformersを使う") - parser.add_argument("--vae", type=str, default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") - parser.add_argument("--cache_latents", action="store_true", - help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") - parser.add_argument("--enable_bucket", action="store_true", - help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True) + train_util.add_training_arguments(parser, True) + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") - # parser.add_argument("--stop_text_encoder_training", type=int, default=None, - # help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数") - 
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument("--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1, - help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - parser.add_argument("--logging_dir", type=str, default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + + parser.add_argument("--network_weights", type=str, default=None, + help="pretrained weights for network / 学習するネットワークの初期重み") parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') parser.add_argument("--network_dim", type=int, default=None, help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') From f56988b2529a21febe3d784b1fccf13bd0e7df27 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 5 Jan 2023 08:10:22 +0900 Subject: [PATCH 05/26] unify dataset and save functions --- fine_tune.py | 300 ++-------------- library/model_util.py | 8 - library/train_util.py | 153 ++++++-- train_db.py | 788 +++++------------------------------------- train_network.py | 54 ++- 5 files changed, 287 insertions(+), 1016 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 5da37b68..8b06abda 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -1,27 +1,17 @@ # training with captions -# XXX dropped option: fine_tune +# XXX dropped option: hypernetwork training import argparse import gc import math import os -import random -import json -import importlib -import time from tqdm import tqdm import torch -from accelerate import Accelerator from accelerate.utils import set_seed -from transformers import CLIPTokenizer import diffusers -from diffusers import DDPMScheduler, StableDiffusionPipeline -import numpy as np -from einops import rearrange -from torch import einsum +from diffusers import DDPMScheduler -import library.model_util as model_util import library.train_util as train_util @@ -29,211 +19,21 @@ def collate_fn(examples): return examples[0] -class FineTuningDataset(torch.utils.data.Dataset): - def __init__(self, metadata, train_data_dir, batch_size, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, dataset_repeats, debug) -> None: - super().__init__() - - self.metadata = metadata - self.train_data_dir = train_data_dir - self.batch_size = batch_size - self.tokenizer: 
CLIPTokenizer = tokenizer - self.max_token_length = max_token_length - self.shuffle_caption = shuffle_caption - self.shuffle_keep_tokens = shuffle_keep_tokens - self.debug = debug - - self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 - - print("make buckets") - - # 最初に数を数える - self.bucket_resos = set() - for img_md in metadata.values(): - if 'train_resolution' in img_md: - self.bucket_resos.add(tuple(img_md['train_resolution'])) - self.bucket_resos = list(self.bucket_resos) - self.bucket_resos.sort() - print(f"number of buckets: {len(self.bucket_resos)}") - - reso_to_index = {} - for i, reso in enumerate(self.bucket_resos): - reso_to_index[reso] = i - - # bucketに割り当てていく - self.buckets = [[] for _ in range(len(self.bucket_resos))] - n = 1 if dataset_repeats is None else dataset_repeats - images_count = 0 - for image_key, img_md in metadata.items(): - if 'train_resolution' not in img_md: - continue - if not os.path.exists(self.image_key_to_npz_file(image_key)): - continue - - reso = tuple(img_md['train_resolution']) - for _ in range(n): - self.buckets[reso_to_index[reso]].append(image_key) - images_count += n - - # 参照用indexを作る - self.buckets_indices = [] - for bucket_index, bucket in enumerate(self.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append((bucket_index, batch_index)) - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - self.images_count = images_count - - def show_buckets(self): - for i, (reso, bucket) in enumerate(zip(self.bucket_resos, self.buckets)): - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - for bucket in self.buckets: - random.shuffle(bucket) - - def image_key_to_npz_file(self, image_key): - npz_file_norm = os.path.splitext(image_key)[0] + '.npz' - if os.path.exists(npz_file_norm): - if random.random() < .5: - npz_file_flip = os.path.splitext(image_key)[0] + '_flip.npz' - if os.path.exists(npz_file_flip): - return npz_file_flip - return npz_file_norm - - npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') - if random.random() < .5: - npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') - if os.path.exists(npz_file_flip): - return npz_file_flip - return npz_file_norm - - def load_latent(self, image_key): - return np.load(self.image_key_to_npz_file(image_key))['arr_0'] - - def __len__(self): - return self._length - - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - - bucket = self.buckets[self.buckets_indices[index][0]] - image_index = self.buckets_indices[index][1] * self.batch_size - - input_ids_list = [] - latents_list = [] - captions = [] - for image_key in bucket[image_index:image_index + self.batch_size]: - img_md = self.metadata[image_key] - caption = img_md.get('caption') - tags = img_md.get('tags') - - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{image_key}" - - latents = self.load_latent(image_key) - - if self.shuffle_caption: - tokens = caption.strip().split(",") - if self.shuffle_keep_tokens is None: - random.shuffle(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[:self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens:] - 
random.shuffle(tokens) - tokens = keep_tokens + tokens - caption = ",".join(tokens).strip() - - captions.append(caption) - - input_ids = self.tokenizer(caption, padding="max_length", truncation=True, - max_length=self.tokenizer_max_length, return_tensors="pt").input_ids - - if self.tokenizer_max_length > self.tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) - ids_chunk = (input_ids[0].unsqueeze(0), - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): - ids_chunk = (input_ids[0].unsqueeze(0), # BOS - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) - - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: - ids_chunk[-1] = self.tokenizer.eos_token_id - # 先頭が ... の場合は ... に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id - - iids_list.append(ids_chunk) - - input_ids = torch.stack(iids_list) # 3,77 - - input_ids_list.append(input_ids) - latents_list.append(torch.FloatTensor(latents)) - - example = {} - example['input_ids'] = torch.stack(input_ids_list) - example['latents'] = torch.stack(latents_list) - if self.debug: - example['image_keys'] = bucket[image_index:image_index + self.batch_size] - example['captions'] = captions - return example - - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) cache_latents = args.cache_latents - # verify load/save model formats - load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) - - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path - - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する tokenizer = train_util.load_tokenizer(args) - train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset) + train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, + tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, + args.resolution, 
args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset) train_dataset.make_buckets() if args.debug_dataset: @@ -253,6 +53,21 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + # Diffusers版のxformers使用フラグを設定する関数 def set_diffusers_xformers_flag(model, valid): # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう @@ -308,7 +123,11 @@ def train(args): else: text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) # text encoderは学習しない - text_encoder.eval() + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + text_encoder.train() # required for gradient_checkpointing + else: + text_encoder.eval() if not cache_latents: vae.requires_grad_(False) @@ -365,12 +184,7 @@ def train(args): # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: - org_unscale_grads = accelerator.scaler._unscale_grads_ - - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする if args.resume is not None: @@ -413,7 +227,6 @@ def train(args): latents = latents * 0.18215 b_size = latents.shape[0] - # with torch.no_grad(): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) @@ -435,7 +248,6 @@ def train(args): if args.v_parameterization: # v-parameterization training - # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise @@ -478,63 +290,26 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - def save_func(file): - model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), - src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) - train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func) - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: - print("saving checkpoint.") - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) - - if save_stable_diffusion_format: - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), - src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) - else: - out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) - 
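The save-format branch above can be read as a small mapping from --save_model_as to (single-file SD checkpoint?, safetensors?). Restated as a standalone sketch (the function name here is only for illustration):

def resolve_save_format(save_model_as, load_stable_diffusion_format, use_safetensors_flag):
    # None: keep whatever format the source model was loaded from
    if save_model_as is None:
        return load_stable_diffusion_format, use_safetensors_flag
    save_sd_single_file = save_model_as.lower() in ('ckpt', 'safetensors')
    use_safetensors = use_safetensors_flag or ('safetensors' in save_model_as.lower())
    return save_sd_single_file, use_safetensors

print(resolve_save_format('safetensors', False, False))           # (True, True)
print(resolve_save_format('diffusers', True, False))              # (False, False)
print(resolve_save_format('diffusers_safetensors', True, False))  # (False, True)
print(resolve_save_format(None, True, False))                     # (True, False)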
os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), - src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) is_main_process = accelerator.is_main_process if is_main_process: - if fine_tuning: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) - else: - hypernetwork = unwrap_model(hypernetwork) + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) accelerator.end_training() if args.save_state: - print("saving last state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME)) + train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors)) - - if fine_tuning: - if save_stable_diffusion_format: - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae) - else: - # Create the pipeline using using the trained modules and save it. - print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = os.path.join(args.output_dir, train_util.LAST_DIFFUSERS_DIR_NAME) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, - src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) - else: - print(f"save trained model to {ckpt_file}") - save_hypernetwork(ckpt_file, hypernetwork) - + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, global_step, text_encoder, unet, vae) print("model saved.") @@ -544,9 +319,8 @@ if __name__ == '__main__': train_util.add_sd_models_arguments(parser) train_util.add_dataset_arguments(parser, False, True) train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") parser.add_argument("--diffusers_xformers", action='store_true', help='use xformers by diffusers / Diffusersでxformersを使用する') diff --git a/library/model_util.py b/library/model_util.py index 398b6404..bc824a12 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -1133,14 +1133,6 @@ def load_vae(vae_id, dtype): return vae -def get_epoch_ckpt_name(use_safetensors, epoch): - return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt") - - -def get_last_ckpt_name(use_safetensors): - return f"last" + (".safetensors" if use_safetensors else ".ckpt") - - # endregion diff --git a/library/train_util.py b/library/train_util.py index 
8eedf48c..5033a55b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1,7 +1,9 @@ # common functions for training +# TODO test no_token_padding option import argparse import json +import shutil import time from typing import NamedTuple from accelerate import Accelerator @@ -31,18 +33,16 @@ TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ # checkpointファイル名 -EPOCH_STATE_NAME = "epoch-{:06d}-state" -LAST_STATE_NAME = "last-state" - -EPOCH_FILE_NAME = "epoch-{:06d}" -LAST_FILE_NAME = "last" - -LAST_DIFFUSERS_DIR_NAME = "last" -EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" - +EPOCH_STATE_NAME = "{}-{:06d}-state" +EPOCH_FILE_NAME = "{}-{:06d}" +EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}" +LAST_STATE_NAME = "{}-state" +DEFAULT_EPOCH_NAME = "epoch" +DEFAULT_LAST_OUTPUT_NAME = "last" # region dataset + class ImageInfo(): def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: self.image_key: str = image_key @@ -76,6 +76,7 @@ class BaseDataset(torch.utils.data.Dataset): self.flip_aug = flip_aug self.color_aug = color_aug self.debug_dataset = debug_dataset + self.padding_disabled = False self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -101,6 +102,9 @@ class BaseDataset(torch.utils.data.Dataset): self.image_data: dict[str, ImageInfo] = {} + def disable_padding(self): + self.padding_disabled = True + def process_caption(self, caption): if self.shuffle_caption: tokens = caption.strip().split(",") @@ -408,11 +412,18 @@ class BaseDataset(torch.utils.data.Dataset): caption = self.process_caption(image_info.caption) captions.append(caption) - input_ids_list.append(self.get_input_ids(caption)) + if not self.padding_disabled: # this option might be omitted in future + input_ids_list.append(self.get_input_ids(caption)) example = {} example['loss_weights'] = torch.FloatTensor(loss_weights) - example['input_ids'] = torch.stack(input_ids_list) + + if self.padding_disabled: + # padding=True means pad in the batch + example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example['input_ids'] = torch.stack(input_ids_list) if images[0] is not None: images = torch.stack(images) @@ -664,6 +675,7 @@ class FineTuningDataset(BaseDataset): return npz_file_norm, npz_file_flip + def debug_dataset(train_dataset): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print("Escape for exit. 
/ Escキーで中断、終了します") @@ -973,12 +985,13 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") + parser.add_argument("--output_name", type=str, default=None, + help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") parser.add_argument("--save_precision", type=str, default=None, choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") - parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -1034,6 +1047,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b parser.add_argument("--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") parser.add_argument("--keep_tokens", type=int, default=None, help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") @@ -1064,7 +1079,19 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") -def prepare_dataset_args(args: argparse.Namespace, support_caption: bool): +def add_sd_saving_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], + help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") + parser.add_argument("--use_safetensors", action='store_true', + help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") + + +def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): + # backward compatibility + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + args.caption_extention = None + if args.cache_latents: assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません" @@ -1083,7 +1110,7 @@ def prepare_dataset_args(args: argparse.Namespace, support_caption: bool): else: args.face_crop_aug_range = None - if support_caption: + if support_metadata: if args.in_json is not None and args.color_aug: print(f"latents in npz is ignored when color_aug is True / 
color_augを有効にした場合、npzファイルのlatentsは無視されます") @@ -1216,29 +1243,95 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod return encoder_hidden_states -def save_on_epoch_end(args: argparse.Namespace, accelerator, epoch: int, num_train_epochs: int, save_func): - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: +def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): + model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") + return model_name, ckpt_name + + +def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + remove_epoch_no = None + if saving: print("saving checkpoint.") os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as) - save_func(ckpt_file) + save_func() - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + if args.save_last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + remove_old_func(remove_epoch_no) + return saving, remove_epoch_no -def save_last_state(args, accelerator): +def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae): + epoch_no = epoch + 1 + model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) + + if save_stable_diffusion_format: + def save_sd(): + ckpt_file = os.path.join(args.output_dir, ckpt_name) + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, + src_path, epoch_no, global_step, save_dtype, vae) + + def remove_sd(old_epoch_no): + _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + os.remove(old_ckpt_file) + + save_func = save_sd + remove_old_func = remove_sd + else: + def save_du(): + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + src_path, vae=vae, use_safetensors=use_safetensors) + + def remove_du(old_epoch_no): + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) + if os.path.exists(out_dir_old): + shutil.rmtree(out_dir_old) + + save_func = save_du + remove_old_func = remove_du + + saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) + if saving and args.save_state: + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) + if remove_epoch_no is not None: + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + shutil.rmtree(state_dir_old) + + +def save_state_on_train_end(args: argparse.Namespace, accelerator): print("saving last state.") os.makedirs(args.output_dir, exist_ok=True) - 
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) -def save_last_model(args, save_func): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' + args.save_model_as) - print(f"save trained model to {ckpt_file}") - save_func(ckpt_file) - print("model saved.") +def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae): + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + + if save_stable_diffusion_format: + os.makedirs(args.output_dir, exist_ok=True) + + ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, + src_path, epoch, global_step, save_dtype, vae) + else: + print(f"save trained model as Diffusers to {args.output_dir}") + + out_dir = os.path.join(args.output_dir, model_name) + os.makedirs(out_dir, exist_ok=True) + + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + src_path, vae=vae, use_safetensors=use_safetensors) # endregion diff --git a/train_db.py b/train_db.py index 4a30d2d5..a03ab563 100644 --- a/train_db.py +++ b/train_db.py @@ -1,4 +1,5 @@ # DreamBooth training +# XXX dropped option: fine_tune import gc import time @@ -31,364 +32,49 @@ import library.train_util as train_util from library.train_util import DreamBoothDataset, FineTuningDataset -# region dataset - -class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): - def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None: - super().__init__() - - self.batch_size = batch_size - self.fine_tuning = fine_tuning - self.train_img_path_captions = train_img_path_captions - self.reg_img_path_captions = reg_img_path_captions - self.tokenizer = tokenizer - self.width, self.height = resolution - self.size = min(self.width, self.height) # 短いほう - self.prior_loss_weight = prior_loss_weight - self.face_crop_aug_range = face_crop_aug_range - self.random_crop = random_crop - self.debug_dataset = debug_dataset - self.shuffle_caption = shuffle_caption - self.disable_padding = disable_padding - self.latents_cache = None - self.enable_bucket = False - - # augmentation - flip_p = 0.5 if flip_aug else 0.0 - if color_aug: - # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hue/saturationあたりを触る - self.aug = albu.Compose([ - albu.OneOf([ - # albu.RandomBrightnessContrast(0.05, 0.05, p=.2), - albu.HueSaturationValue(5, 8, 0, p=.2), - # albu.RGBShift(5, 5, 5, p=.1), - albu.RandomGamma((95, 105), p=.5), - ], p=.33), - albu.HorizontalFlip(p=flip_p) - ], p=1.) - elif flip_aug: - self.aug = albu.Compose([ - albu.HorizontalFlip(p=flip_p) - ], p=1.) 
- else: - self.aug = None - - self.num_train_images = len(self.train_img_path_captions) - self.num_reg_images = len(self.reg_img_path_captions) - - self.enable_reg_images = self.num_reg_images > 0 - - if self.enable_reg_images and self.num_train_images < self.num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - self.image_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - # bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size): - self.enable_bucket = enable_bucket - - cache_latents = vae is not None - if cache_latents: - if enable_bucket: - print("cache latents with bucketing") - else: - print("cache latents") - else: - if enable_bucket: - print("make buckets") - else: - print("prepare dataset") - - # bucketingを用意する - if enable_bucket: - bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size) - else: - # bucketはひとつだけ、すべての画像は同じ解像度 - bucket_resos = [(self.width, self.height)] - bucket_aspect_ratios = [self.width / self.height] - bucket_aspect_ratios = np.array(bucket_aspect_ratios) - - # 画像の解像度、latentをあらかじめ取得する - img_ar_errors = [] - self.size_lat_cache = {} - for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions): - if image_path in self.size_lat_cache: - continue - - image = self.load_image(image_path)[0] - image_height, image_width = image.shape[0:2] - - if not enable_bucket: - # assert image_width == self.width and image_height == self.height, \ - # f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}" - reso = (self.width, self.height) - else: - # bucketを決める - aspect_ratio = image_width / image_height - ar_errors = bucket_aspect_ratios - aspect_ratio - bucket_id = np.abs(ar_errors).argmin() - reso = bucket_resos[bucket_id] - ar_error = ar_errors[bucket_id] - img_ar_errors.append(ar_error) - - if cache_latents: - image = self.resize_and_trim(image, reso) - - # latentを取得する - if cache_latents: - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - else: - latents = None - - self.size_lat_cache[image_path] = (reso, latents) - - # 画像をbucketに分割する - self.buckets = [[] for _ in range(len(bucket_resos))] - reso_to_index = {} - for i, reso in enumerate(bucket_resos): - reso_to_index[reso] = i - - def split_to_buckets(is_reg, img_path_captions): - for image_path, caption in img_path_captions: - reso, _ = self.size_lat_cache[image_path] - bucket_index = reso_to_index[reso] - self.buckets[bucket_index].append((is_reg, image_path, caption)) - - split_to_buckets(False, self.train_img_path_captions) - - if self.enable_reg_images: - l = [] - while len(l) < len(self.train_img_path_captions): - l += self.reg_img_path_captions - l = l[:len(self.train_img_path_captions)] - split_to_buckets(True, l) - - if enable_bucket: - print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数") - for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)): - print(f"bucket {i}: resolution {reso}, count: {len(imgs)}") - img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}") - - # 参照用indexを作る - self.buckets_indices = [] - for bucket_index, bucket in enumerate(self.buckets): - batch_count = 
int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append((bucket_index, batch_index)) - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - - # どのサイズにリサイズするか→トリミングする方向で - def resize_and_trim(self, image, reso): - image_height, image_width = image.shape[0:2] - ar_img = image_width / image_height - ar_reso = reso[0] / reso[1] - if ar_img > ar_reso: # 横が長い→縦を合わせる - scale = reso[1] / image_height - else: - scale = reso[0] / image_width - resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) - - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - if resized_size[0] > reso[0]: - trim_size = resized_size[0] - reso[0] - image = image[:, trim_size//2:trim_size//2 + reso[0]] - elif resized_size[1] > reso[1]: - trim_size = resized_size[1] - reso[1] - image = image[trim_size//2:trim_size//2 + reso[1]] - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ - f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - for bucket in self.buckets: - random.shuffle(bucket) - - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - - face_cx = face_cy = face_w = face_h = 0 - if self.face_crop_aug_range is not None: - tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') - if len(tokens) >= 5: - face_cx = int(tokens[-4]) - face_cy = int(tokens[-3]) - face_w = int(tokens[-2]) - face_h = int(tokens[-1]) - - return img, face_cx, face_cy, face_w, face_h - - # いい感じに切り出す - def crop_target(self, image, face_cx, face_cy, face_w, face_h): - height, width = image.shape[0:2] - if height == self.height and width == self.width: - return image - - # 画像サイズはsizeより大きいのでリサイズする - face_size = max(face_w, face_h) - min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ - if min_scale >= max_scale: # range指定がmin==max - scale = min_scale - else: - scale = random.uniform(min_scale, max_scale) - - nh = int(height * scale + .5) - nw = int(width * scale + .5) - assert nh >= self.height and nw >= self.width, f"internal error. 
small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) - face_cx = int(face_cx * scale + .5) - face_cy = int(face_cy * scale + .5) - height, width = nh, nw - - # 顔を中心として448*640とかへを切り出す - for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): - p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - - if self.random_crop: - # 背景も含めるために顔を中心に置く確率を高めつつずらす - range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう - p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 - else: - # range指定があるときのみ、すこしだけランダムに(わりと適当) - if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: - if face_size > self.size // 10 and face_size >= 40: - p1 = p1 + random.randint(-face_size // 20, +face_size // 20) - - p1 = max(0, min(p1, length - target_size)) - - if axis == 0: - image = image[p1:p1 + target_size, :] - else: - image = image[:, p1:p1 + target_size] - - return image - - def __len__(self): - return self._length - - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - - bucket = self.buckets[self.buckets_indices[index][0]] - image_index = self.buckets_indices[index][1] * self.batch_size - - latents_list = [] - images = [] - captions = [] - loss_weights = [] - - for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]: - loss_weights.append(self.prior_loss_weight if is_reg else 1.0) - - # image/latentsを処理する - reso, latents = self.size_lat_cache[image_path] - - if latents is None: - # 画像を読み込み必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image(image_path) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img = self.resize_and_trim(img, reso) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p:p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p:p + self.width] - - im_h, im_w = img.shape[0:2] - assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}" - - # augmentation - if self.aug is not None: - img = self.aug(image=img)['image'] - - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - else: - image = None - - images.append(image) - latents_list.append(latents) - - # captionを処理する - if self.shuffle_caption: # captionのshuffleをする - tokens = caption.strip().split(",") - random.shuffle(tokens) - caption = ",".join(tokens).strip() - captions.append(caption) - - # input_idsをpadしてTensor変換 - if self.disable_padding: - # paddingしない:padding==Trueはバッチの中の最大長に合わせるだけ(やはりバグでは……?) 
- input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids - else: - # paddingする - input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids - - example = {} - example['loss_weights'] = torch.FloatTensor(loss_weights) - example['input_ids'] = input_ids - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example['images'] = images - example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None - if self.debug_dataset: - example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]] - example['captions'] = captions - return example -# endregion - - def collate_fn(examples): return examples[0] def train(args): - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - args.caption_extention = None + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, False) - fine_tuning = args.fine_tuning cache_latents = args.cache_latents - # latentsをキャッシュする場合のオプション設定を確認する - if cache_latents: - assert not args.flip_aug and not args.color_aug, "when caching latents, augmentation cannot be used / latentをキャッシュするときはaugmentationは使えません" + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する - # その他のオプション設定を確認する - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + tokenizer = train_util.load_tokenizer(args) - # モデル形式のオプション設定を確認する: - load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) + train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, + tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + if args.no_token_padding: + train_dataset.disable_padding() + train_dataset.make_buckets() + if args.debug_dataset: + train_util.debug_dataset(train_dataset) + + # acceleratorを準備する + print("prepare accelerator") + + if args.gradient_accumulation_steps > 1: + print(f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. 
accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong") + print( + f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です") + + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + + # verify load/save model formats if load_stable_diffusion_format: src_stable_diffusion_ckpt = args.pretrained_model_name_or_path src_diffusers_model_path = None @@ -403,202 +89,6 @@ def train(args): save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # 乱数系列を初期化する - if args.seed is not None: - set_seed(args.seed) - - # 学習データを用意する - def read_caption(img_path): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + args.caption_extension, base_name_face_det + args.caption_extension] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding='utf-8') as f: - lines = f.readlines() - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() - break - return caption - - def load_dreambooth_dir(dir): - tokens = os.path.basename(dir).split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - return 0, [] - - caption_by_folder = '_'.join(tokens[1:]) - - print(f"found directory {n_repeats}_{caption_by_folder}") - - img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ - glob.glob(os.path.join(dir, "*.webp")) - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う(v11から仕様変更した) - captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path) - captions.append(caption_by_folder if cap_for_img is None else cap_for_img) - - return n_repeats, list(zip(img_paths, captions)) - - print("prepare train images.") - train_img_path_captions = [] - - if fine_tuning: - img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) - for img_path in tqdm(img_paths): - caption = read_caption(img_path) - assert caption is not None and len( - caption) > 0, f"no caption for image. 
check caption_extension option / キャプションファイルが見つからないかcaptionが空です。caption_extensionオプションを確認してください: {img_path}" - - train_img_path_captions.append((img_path, caption)) - - if args.dataset_repeats is not None: - l = [] - for _ in range(args.dataset_repeats): - l.extend(train_img_path_captions) - train_img_path_captions = l - else: - train_dirs = os.listdir(args.train_data_dir) - for dir in train_dirs: - n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir)) - for _ in range(n_repeats): - train_img_path_captions.extend(img_caps) - print(f"{len(train_img_path_captions)} train images with repeating.") - - reg_img_path_captions = [] - if args.reg_data_dir: - print("prepare reg images.") - reg_dirs = os.listdir(args.reg_data_dir) - for dir in reg_dirs: - n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.reg_data_dir, dir)) - for _ in range(n_repeats): - reg_img_path_captions.extend(img_caps) - print(f"{len(reg_img_path_captions)} reg images.") - - # データセットを準備する - resolution = tuple([int(r) for r in args.resolution.split(',')]) - if len(resolution) == 1: - resolution = (resolution[0], resolution[0]) - assert len(resolution) == 2, \ - f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}" - - if args.enable_bucket: - assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or greater than resolution / min_bucket_resoは解像度の数値以上で指定してください" - assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or less than resolution / max_bucket_resoは解像度の数値以下で指定してください" - - if args.face_crop_aug_range is not None: - face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) - assert len( - face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" - else: - face_crop_aug_range = None - - # tokenizerを読み込む - print("prepare tokenizer") - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(train_util.V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(train_util.TOKENIZER_PATH) - - print("prepare dataset") - train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, - args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, - args.shuffle_caption, args.no_token_padding, args.debug_dataset) - - if args.debug_dataset: - train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, - args.max_bucket_reso) # デバッグ用にcacheなしで作る - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. 
/ Escキーで中断、終了します") - for example in train_dataset: - for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']): - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}') - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == 27: - break - return - - # acceleratorを準備する - # gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする - print("prepare accelerator") - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = "tensorboard" - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) - accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision, - log_with=log_with, logging_dir=logging_dir) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False - - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 - - # モデルを読み込む - if load_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) - else: - print("load Diffusers pretrained models") - pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) - # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - del pipe - - # # 置換するCLIPを読み込む - # if args.replace_clip_l14_336: - # text_encoder = load_clip_l14_336(weight_dtype) - # print(f"large clip {CLIP_ID_L14_336} is loaded") - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") - # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -608,23 +98,29 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso) + train_dataset.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() - else: - train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso) - vae.requires_grad_(False) - vae.eval() + # 学習を準備する:モデルを適切な状態にする + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + + train_text_encoder = args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 - text_encoder.requires_grad_(True) + 
text_encoder.requires_grad_(train_text_encoder) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") @@ -639,7 +135,10 @@ def train(args): else: optimizer_class = torch.optim.AdamW - trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters())) + if train_text_encoder: + trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters())) + else: + trainable_params = unet.parameters() # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 optimizer = optimizer_class(trainable_params, lr=args.learning_rate) @@ -662,20 +161,15 @@ def train(args): text_encoder.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - - if not cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: - org_unscale_grads = accelerator.scaler._unscale_grads_ - - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする if args.resume is not None: @@ -683,7 +177,8 @@ def train(args): accelerator.load_state(args.resume) # epoch数を計算する - num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader)) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # 学習する total_batch_size = args.train_batch_size # * accelerator.num_processes @@ -700,33 +195,28 @@ def train(args): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # v12で更新:clip_sample=Falseに - # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる - # 既存の1.4/1.5/2.0/2.1はすべてschedulerのconfigは(クラス名を除いて)同じ - # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀')  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False) if accelerator.is_main_process: accelerator.init_trackers("dreambooth") - # 以下 train_dreambooth.py からほぼコピペ for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") # 指定したステップ数までText Encoderを学習する:epoch最初の状態 - train_text_encoder = args.stop_text_encoder_training is None or global_step < args.stop_text_encoder_training unet.train() - if train_text_encoder: + # train==True is required to enable gradient_checkpointing + if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: text_encoder.train() loss_total = 0 for step, batch in enumerate(train_dataloader): # 指定したステップ数でText Encoderの学習を止める - stop_text_encoder_training = 
args.stop_text_encoder_training is not None and global_step == args.stop_text_encoder_training - if stop_text_encoder_training: + if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") - text_encoder.train(False) + if not args.gradient_checkpointing: + text_encoder.train(False) text_encoder.requires_grad_(False) with accelerator.accumulate(unet): @@ -742,6 +232,11 @@ def train(args): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] + # Get the text embedding for conditioning + with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder) + # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = timesteps.long() @@ -750,20 +245,11 @@ def train(args): # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning - if args.clip_skip is None: - encoder_hidden_states = text_encoder(batch["input_ids"])[0] - else: - enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training - # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise @@ -778,7 +264,10 @@ def train(args): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) + if train_text_encoder: + params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) + else: + params_to_clip = unet.parameters() accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm) optimizer.step() @@ -810,35 +299,9 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: - print("saving checkpoint.") - if save_stable_diffusion_format: - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), - src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) - if args.save_last_n_epochs is not None: - old_ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) - if os.path.exists(old_ckpt_file): - os.remove(old_ckpt_file) - else: - out_dir = os.path.join(args.output_dir, train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), - unwrap_model(unet), src_diffusers_model_path, - use_safetensors=use_safetensors) - if args.save_last_n_epochs is not None: - out_dir_old = os.path.join(args.output_dir, 
train_util.EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) - if os.path.exists(out_dir_old): - shutil.rmtree(out_dir_old) - - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1))) - if args.save_last_n_epochs is not None: - state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(epoch + 1 - args.save_every_n_epochs * args.save_last_n_epochs)) - if os.path.exists(state_dir_old): - shutil.rmtree(state_dir_old) + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) is_main_process = accelerator.is_main_process if is_main_process: @@ -854,107 +317,24 @@ def train(args): del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - if save_stable_diffusion_format: - ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors)) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae) - else: - print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = os.path.join(args.output_dir, train_util.LAST_DIFFUSERS_DIR_NAME) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_diffusers_model_path, - use_safetensors=use_safetensors) + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, global_step, text_encoder, unet, vae) print("model saved.") if __name__ == '__main__': - # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument("--fine_tuning", action="store_true", - help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする") - parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - 
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") - parser.add_argument("--dataset_repeats", type=int, default=None, - help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数") - parser.add_argument("--output_dir", type=str, default=None, - help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)") - parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], - help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") - parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") - parser.add_argument("--save_last_n_epochs", type=int, default=None, - help="save last N checkpoints / 最大Nエポック保存する") - parser.add_argument("--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_sd_saving_arguments(parser) + parser.add_argument("--no_token_padding", action="store_true", help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") parser.add_argument("--stop_text_encoder_training", type=int, default=None, - help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数") - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") - parser.add_argument("--face_crop_aug_range", type=str, default=None, - help="enable face-centered crop augmentation and its range (e.g. 
2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") - parser.add_argument("--random_crop", action="store_true", - help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") - parser.add_argument("--debug_dataset", action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - parser.add_argument("--resolution", type=str, default=None, - help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") - parser.add_argument("--train_batch_size", type=int, default=1, - help="batch size for training (1 means one train or reg data, not train/reg pair) / 学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)") - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--mem_eff_attn", action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") - parser.add_argument("--xformers", action="store_true", - help="use xformers for CrossAttention / CrossAttentionにxformersを使う") - parser.add_argument("--vae", type=str, default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") - parser.add_argument("--cache_latents", action="store_true", - help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") - parser.add_argument("--enable_bucket", action="store_true", - help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument("--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - parser.add_argument("--logging_dir", type=str, default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + help="steps to stop text encoder training, -1 for no training / 
Text Encoderの学習を止めるステップ数、-1で最初から学習しない") args = parser.parse_args() train(args) - diff --git a/train_network.py b/train_network.py index 35f50567..bfb2d860 100644 --- a/train_network.py +++ b/train_network.py @@ -1,6 +1,7 @@ import gc import importlib import json +import shutil import time import argparse import math @@ -143,8 +144,6 @@ def train(args): if args.full_fp16: assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" print("enable full fp16 training.") - # unet.to(weight_dtype) - # text_encoder.to(weight_dtype) network.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい @@ -163,10 +162,14 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) - unet.eval() text_encoder.requires_grad_(False) text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.eval() + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + text_encoder.train() + else: + unet.eval() + text_encoder.eval() network.prepare_grad_etc(text_encoder, unet) @@ -294,9 +297,29 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - def save_func(file): - unwrap_model(network).save_weights(file, save_dtype) - train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func) + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + unwrap_model(network).save_weights(ckpt_file, save_dtype) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + os.remove(old_ckpt_file) + + saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, epoch + 1))) + if remove_epoch_no is not None: + state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + shutil.rmtree(state_dir_old) + + # end of epoch is_main_process = accelerator.is_main_process if is_main_process: @@ -305,14 +328,20 @@ def train(args): accelerator.end_training() if args.save_state: - train_util.save_last_state(args, accelerator) + train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す if is_main_process: - def last_save_func(file): - network.save_weights(file, save_dtype) - train_util.save_last_model(args, last_save_func) + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + '.' 
+ args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model to {ckpt_file}") + network.save_weights(ckpt_file, save_dtype) + print("model saved.") if __name__ == '__main__': @@ -322,6 +351,9 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, True, True) train_util.add_training_arguments(parser, True) + parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") From 1b222dbf9b058164bf6f09977686dd4b51a7f4a8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 6 Jan 2023 17:13:56 +0900 Subject: [PATCH 06/26] erase using of deleted property --- train_db.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_db.py b/train_db.py index a03ab563..89a5d492 100644 --- a/train_db.py +++ b/train_db.py @@ -181,15 +181,15 @@ def train(args): num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # 学習する - total_batch_size = args.train_batch_size # * accelerator.num_processes + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") - print(f" num examples / サンプル数: {train_dataset.num_train_images * (2 if train_dataset.enable_reg_images else 1)}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 From 2efced0a9ac187c46a6c395386228b7461503075 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Jan 2023 20:19:25 +0900 Subject: [PATCH 07/26] fix training starts with debug_dataset --- train_db.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/train_db.py b/train_db.py index 89a5d492..a47da472 100644 --- a/train_db.py +++ b/train_db.py @@ -3,33 +3,19 @@ import gc import time -from torch.autograd.function import Function import argparse -import glob import itertools import math import os -import random -import shutil from tqdm import tqdm import torch -from torchvision import transforms -from accelerate import Accelerator from accelerate.utils import set_seed -from transformers import CLIPTokenizer import diffusers -from diffusers import DDPMScheduler, StableDiffusionPipeline -import albumentations as albu -import numpy as np -from PIL import Image -import cv2 -from einops import rearrange -from torch import einsum -import library.model_util as model_util import library.train_util as train_util -from
library.train_util import DreamBoothDataset, FineTuningDataset +from library.train_util import DreamBoothDataset def collate_fn(examples): @@ -52,11 +38,12 @@ def train(args): args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) if args.no_token_padding: - train_dataset.disable_padding() + train_dataset.disable_token_padding() train_dataset.make_buckets() if args.debug_dataset: train_util.debug_dataset(train_dataset) + return # acceleratorを準備する print("prepare accelerator") @@ -311,8 +298,7 @@ def train(args): accelerator.end_training() if args.save_state: - print("saving last state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME)) + train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す From 9f1d3aca2416b1b7ab37cee01ddbca7bcff6bb88 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Jan 2023 20:20:37 +0900 Subject: [PATCH 08/26] add save_state_on_train end, fix reg imgs repeats --- library/train_util.py | 59 +++++++++++++++++++++++++++---------------- train_network.py | 9 +++---- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5033a55b..2eb16c00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -76,7 +76,7 @@ class BaseDataset(torch.utils.data.Dataset): self.flip_aug = flip_aug self.color_aug = color_aug self.debug_dataset = debug_dataset - self.padding_disabled = False + self.token_padding_disabled = False self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -102,8 +102,8 @@ class BaseDataset(torch.utils.data.Dataset): self.image_data: dict[str, ImageInfo] = {} - def disable_padding(self): - self.padding_disabled = True + def disable_token_padding(self): + self.token_padding_disabled = True def process_caption(self, caption): if self.shuffle_caption: @@ -412,13 +412,13 @@ class BaseDataset(torch.utils.data.Dataset): caption = self.process_caption(image_info.caption) captions.append(caption) - if not self.padding_disabled: # this option might be omitted in future + if not self.token_padding_disabled: # this option might be omitted in future input_ids_list.append(self.get_input_ids(caption)) example = {} example['loss_weights'] = torch.FloatTensor(loss_weights) - if self.padding_disabled: + if self.token_padding_disabled: # padding=True means pad in the batch example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids else: @@ -540,13 +540,20 @@ class DreamBoothDataset(BaseDataset): if num_reg_images == 0: print("no regularization images / 正則化画像が見つかりませんでした") else: + # num_repeatsを計算する:どうせ大した数ではないのでループで処理する n = 0 + first_loop = True while n < num_train_images: for info in reg_infos: - self.register_image(info) - n += info.num_repeats - if n >= num_train_images: # reg画像にnum_repeats>1のときはまずありえないので考慮しない + if first_loop: + self.register_image(info) + n += info.num_repeats + else: + info.num_repeats += 1 + n += 1 + if n >= num_train_images: break + first_loop = False self.num_reg_images = num_reg_images @@ -1253,7 +1260,6 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs remove_epoch_no = None if saving: - print("saving checkpoint.") os.makedirs(args.output_dir, 
exist_ok=True) save_func() @@ -1270,6 +1276,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: if save_stable_diffusion_format: def save_sd(): ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae) @@ -1277,6 +1284,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) save_func = save_sd @@ -1284,6 +1292,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: else: def save_du(): out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + print(f"saving model: {out_dir}") os.makedirs(out_dir, exist_ok=True) model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors) @@ -1291,6 +1300,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: def remove_du(old_epoch_no): out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) if os.path.exists(out_dir_old): + print(f"removing old model: {out_dir_old}") shutil.rmtree(out_dir_old) save_func = save_du @@ -1298,19 +1308,17 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) if saving and args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) - if remove_epoch_no is not None: - state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) - if os.path.exists(state_dir_old): - shutil.rmtree(state_dir_old) + save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no) -def save_state_on_train_end(args: argparse.Namespace, accelerator): - print("saving last state.") - os.makedirs(args.output_dir, exist_ok=True) - model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) +def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no): + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) + if remove_epoch_no is not None: + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + print(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae): @@ -1326,12 +1334,19 @@ def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_sta model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae) else: - print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = 
os.path.join(args.output_dir, model_name) os.makedirs(out_dir, exist_ok=True) + print(f"save trained model as Diffusers to {out_dir}") model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors) + +def save_state_on_train_end(args: argparse.Namespace, accelerator): + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) + + # endregion diff --git a/train_network.py b/train_network.py index bfb2d860..24dfa5b0 100644 --- a/train_network.py +++ b/train_network.py @@ -302,22 +302,19 @@ def train(args): def save_func(): ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") unwrap_model(network).save_weights(ckpt_file, save_dtype) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) if saving and args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, epoch + 1))) - if remove_epoch_no is not None: - state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) - if os.path.exists(state_dir_old): - shutil.rmtree(state_dir_old) + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no) # end of epoch From 80af4c0c424e4561c33359e5669f6fe9f9af0e39 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Jan 2023 21:43:27 +0900 Subject: [PATCH 09/26] Set dtype if text encoder is not trained at all --- train_db.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/train_db.py b/train_db.py index a47da472..d1ef350c 100644 --- a/train_db.py +++ b/train_db.py @@ -98,6 +98,8 @@ def train(args): train_text_encoder = args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 text_encoder.requires_grad_(train_text_encoder) + if not train_text_encoder: + print("Text Encoder is not trained.") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -153,6 +155,9 @@ def train(args): unet, text_encoder, optimizer, train_dataloader, lr_scheduler) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: From 82e585cf01d6a776db0212206672bd066ed4499e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 8 Jan 2023 18:49:34 +0900 Subject: [PATCH 10/26] Fix full_fp16 and clip_skip==2 is not working --- fine_tune.py | 3 ++- train_db.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 8b06abda..fa3c81be 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -230,7 +230,8 @@ def train(args): with torch.set_grad_enabled(args.train_text_encoder): 
# Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) diff --git a/train_db.py b/train_db.py index d1ef350c..8c9cdb95 100644 --- a/train_db.py +++ b/train_db.py @@ -155,7 +155,7 @@ def train(args): unet, text_encoder, optimizer, train_dataloader, lr_scheduler) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - + if not train_text_encoder: text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error @@ -227,7 +227,8 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) From 1945fa186d76fb5545b1f6491b77512bf4953320 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 8 Jan 2023 18:50:52 +0900 Subject: [PATCH 11/26] Show error if caption isn't UTF-8, add bmp support --- library/train_util.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 2eb16c00..98ad10ef 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1,5 +1,4 @@ # common functions for training -# TODO test no_token_padding option import argparse import json @@ -42,6 +41,8 @@ DEFAULT_LAST_OUTPUT_NAME = "last" # region dataset +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] + class ImageInfo(): def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: @@ -476,7 +477,11 @@ class DreamBoothDataset(BaseDataset): for cap_path in cap_paths: if os.path.isfile(cap_path): with open(cap_path, "rt", encoding='utf-8') as f: - lines = f.readlines() + try: + lines = f.readlines() + except UnicodeDecodeError as e: + print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" caption = lines[0].strip() break @@ -495,8 +500,7 @@ class DreamBoothDataset(BaseDataset): return 0, [], [] caption_by_folder = '_'.join(tokens[1:]) - img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ - glob.glob(os.path.join(dir, "*.webp")) + img_paths = glob_images(dir, "*") print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files") # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う @@ -581,8 +585,7 @@ class FineTuningDataset(BaseDataset): abs_path = image_key else: # わりといい加減だがいい方法が思いつかん - abs_path = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) + - glob.glob(os.path.join(train_data_dir, f"{image_key}.webp"))) + abs_path = glob_images(train_data_dir, image_key) 
assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}" abs_path = abs_path[0] @@ -705,6 +708,12 @@ def debug_dataset(train_dataset): if k == 27 or example['images'] is None: break +def glob_images(dir, base): + img_paths = [] + for ext in IMAGE_EXTENSIONS: + img_paths.extend(glob.glob(os.path.join(dir, base + ext))) + return img_paths + # endregion @@ -1210,6 +1219,10 @@ def patch_accelerator_for_fp16_training(accelerator): def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): + # with no_token_padding, the length is not max length, return result immediately + if input_ids.size()[-1] != tokenizer.model_max_length: + return text_encoder(input_ids)[0] + b_size = input_ids.size()[0] input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 From 6b62c44022aef9ad786b465b59216e488a28de73 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 8 Jan 2023 21:40:40 +0900 Subject: [PATCH 12/26] fix errors in fine tuning --- fine_tune.py | 17 ++++++++++++----- library/train_util.py | 26 +++++++++++++++++--------- train_network.py | 5 +++-- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index fa3c81be..1a94870f 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -33,7 +33,8 @@ def train(args): train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset) + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, + args.dataset_repeats, args.debug_dataset) train_dataset.make_buckets() if args.debug_dataset: @@ -198,7 +199,7 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num examples / サンプル数: {train_dataset.images_count}") + print(f" num examples / サンプル数: {train_dataset.num_train_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -223,8 +224,13 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく - latents = batch["latents"].to(accelerator.device) - latents = latents * 0.18215 + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 b_size = latents.shape[0] with torch.set_grad_enabled(args.train_text_encoder): @@ -310,7 +316,7 @@ def train(args): if is_main_process: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, global_step, text_encoder, unet, vae) + save_dtype, epoch, global_step, text_encoder, unet, vae) print("model saved.") @@ -324,6 +330,7 @@ if __name__ == '__main__': parser.add_argument("--diffusers_xformers", action='store_true', help='use xformers by diffusers / Diffusersでxformersを使用する') + 
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") args = parser.parse_args() train(args) diff --git a/library/train_util.py b/library/train_util.py index 98ad10ef..bad954c2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -65,7 +65,7 @@ class BucketBatchIndex(NamedTuple): class BaseDataset(torch.utils.data.Dataset): - def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None: + def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None: super().__init__() self.tokenizer: CLIPTokenizer = tokenizer self.max_token_length = max_token_length @@ -77,6 +77,7 @@ class BaseDataset(torch.utils.data.Dataset): self.flip_aug = flip_aug self.color_aug = color_aug self.debug_dataset = debug_dataset + self.random_crop = random_crop self.token_padding_disabled = False self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -265,8 +266,9 @@ class BaseDataset(torch.utils.data.Dataset): if info.latents_npz is not None: info.latents = self.load_latents_from_npz(info, False) info.latents = torch.FloatTensor(info.latents) - info.latents_flipped = self.load_latents_from_npz(info, True) - info.latents_flipped = torch.FloatTensor(info.latents_flipped) + info.latents_flipped = self.load_latents_from_npz(info, True) # might be None + if info.latents_flipped is not None: + info.latents_flipped = torch.FloatTensor(info.latents_flipped) continue image = self.load_image(info.absolute_path) @@ -349,6 +351,8 @@ class BaseDataset(torch.utils.data.Dataset): def load_latents_from_npz(self, image_info: ImageInfo, flipped): npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + if npz_file is None: + return None return np.load(npz_file)['arr_0'] def __len__(self): @@ -444,14 +448,13 @@ class BaseDataset(torch.utils.data.Dataset): class DreamBoothDataset(BaseDataset): def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, - resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう self.prior_loss_weight = prior_loss_weight - self.random_crop = random_crop self.latents_cache = None self.enable_bucket = enable_bucket @@ -563,9 +566,9 @@ class DreamBoothDataset(BaseDataset): class FineTuningDataset(BaseDataset): - def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None: + def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, 
enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None: super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, - resolution, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) # メタデータを読み込む if os.path.exists(json_file_name): @@ -639,7 +642,7 @@ class FineTuningDataset(BaseDataset): break sizes.add(image_info.image_size[0]) sizes.add(image_info.image_size[1]) - resos.add(image_info.image_size) + resos.add(tuple(image_info.image_size)) if sizes is None: assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" @@ -708,6 +711,7 @@ def debug_dataset(train_dataset): if k == 27 or example['images'] is None: break + def glob_images(dir, base): img_paths = [] for ext in IMAGE_EXTENSIONS: @@ -986,7 +990,7 @@ def replace_unet_cross_attn_to_xformers(): # endregion -# region utils +# region arguments def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models @@ -1101,6 +1105,10 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser): parser.add_argument("--use_safetensors", action='store_true', help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") +# endregion + +# region utils + def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): # backward compatibility diff --git a/train_network.py b/train_network.py index 24dfa5b0..e557b1de 100644 --- a/train_network.py +++ b/train_network.py @@ -49,7 +49,8 @@ def train(args): train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset) + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, + args.dataset_repeats, args.debug_dataset) train_dataset.make_buckets() if args.debug_dataset: @@ -315,7 +316,7 @@ def train(args): saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) if saving and args.save_state: train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no) - + # end of epoch is_main_process = accelerator.is_main_process From fbaf373c8a88af398836d84937819a48313ad9f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Jan 2023 13:13:37 +0900 Subject: [PATCH 13/26] fix gradient accum not used for lr schduler --- train_network.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/train_network.py b/train_network.py index e557b1de..9f292b97 100644 --- a/train_network.py +++ b/train_network.py @@ -1,22 +1,15 @@ -import gc import importlib -import json -import shutil -import time import argparse +import gc import math import os from tqdm import tqdm import torch from accelerate.utils import set_seed -from transformers import CLIPTokenizer import diffusers -from diffusers import DDPMScheduler, StableDiffusionPipeline -import numpy as np -import cv2 +from diffusers import DDPMScheduler -import library.model_util as model_util import library.train_util as train_util from 
library.train_util import DreamBoothDataset, FineTuningDataset @@ -139,7 +132,7 @@ def train(args): # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler( - args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps) + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: @@ -216,17 +209,16 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - # 指定したステップ数までText Encoderを学習する:epoch最初の状態 network.on_epoch_start(text_encoder, unet) loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(network): with torch.no_grad(): - # latentに変換 - if batch["latents"] is not None: + if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) else: + # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 b_size = latents.shape[0] From 223640e1ae7d65d216be31c6953b923c27e8cbc5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Jan 2023 14:49:56 +0900 Subject: [PATCH 14/26] Add updates. --- README.md | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/README.md b/README.md index 423282cd..8f402187 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ This repository contains training, generation and utility scripts for Stable Diffusion. +**January 9, 2023: Information about the update can be found at the end of the page.** + +**20231/1/9: 更新情報がページ末尾にありますのでご覧ください。** + [日本語版README](./README-ja.md) For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais! @@ -8,6 +12,7 @@ This repository contains the scripts for: * DreamBooth training, including U-Net and Text Encoder * fine-tuning (native training), including U-Net and Text Encoder +* LoRA training * image generation * model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers) @@ -104,3 +109,78 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause + +# Updates: Jan 9. 2023 + +All training scripts are updated. + +## Breaking Changes + +- The ``fine_tuning`` option in ``train_db.py`` is removed. Please use DreamBooth with captions or ``fine_tune.py``. +- The Hypernet feature in ``fine_tune.py`` is removed, will be implemented in ``train_network.py`` in future. + +## Features, Improvements and Bug Fixes + +### for all script: train_db.py, fine_tune.py and train_network.py + +- Added ``output_name`` option. The name of output file can be specified. + - With ``--output_name style1``, the output file is like ``style1_000001.ckpt`` (or ``.safetensors``) for each epoch and ``style1.ckpt`` for last. + - If ommitted (default), same to previous. ``epoch-000001.ckpt`` and ``last.ckpt``. +- Added ``save_last_n_epochs`` option. Keep only latest n files for the checkpoints and the states. Older files are removed. (Thanks to shirayu!) + - If the options are ``--save_every_n_epochs=2 --save_last_n_epochs=3``, in the end of epoch 8, ``epoch-000008.ckpt`` is created and ``epoch-000002.ckpt`` is removed. + +### train_db.py + +- Added ``max_token_length`` option. Captions can have more than 75 tokens. 
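The interaction of ``--save_every_n_epochs`` and ``--save_last_n_epochs`` described in the list above reduces to a simple rule: when the checkpoint for the current epoch is written, the checkpoint saved ``save_every_n_epochs * save_last_n_epochs`` epochs earlier is deleted, so only the newest saved checkpoints remain. A minimal sketch of that rule (the helper name and signature are illustrative only, not part of the scripts):

```python
def epoch_to_remove(epoch: int, save_every_n_epochs: int, save_last_n_epochs: int):
    """Return the epoch whose checkpoint/state should be deleted now, or None."""
    if epoch % save_every_n_epochs != 0:
        return None  # nothing is saved at this epoch, so nothing is removed
    old_epoch = epoch - save_every_n_epochs * save_last_n_epochs
    return old_epoch if old_epoch > 0 else None

# The worked example from the list above: --save_every_n_epochs=2 --save_last_n_epochs=3
print(epoch_to_remove(8, 2, 3))  # -> 2, i.e. epoch-000002.ckpt is removed at the end of epoch 8
```

As in the training scripts, the old file is only removed if it actually exists on disk.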
+ +### fine_tune.py + +- The script now works without .npz files. If .npz is not found, the scripts get the latents with VAE. + - You can omit ``prepare_buckets_latents.py`` in preprocessing. However, it is recommended if you train more than 1 or 2 epochs. + - ``--resolution`` option is required to specify the training resolution. +- Added ``cache_latents`` and ``color_aug`` options. + +### train_network.py + +- Now ``--gradient_checkpointing`` is effective for U-Net and Text Encoder. + - The memory usage is reduced. The larger batch size is avilable, but the training speed will be slow. + - The training might be possible with 6GB VRAM for dimension=4 with batch size=1. + +Documents are not updated now, I will update one by one. + +# 更新情報 (2023/1/9) + +学習スクリプトを更新しました。 + +## 削除された機能 +- ``train_db.py`` の ``fine_tuning`` は削除されました。キャプション付きの DreamBooth または ``fine_tune.py`` を使ってください。 +- ``fine_tune.py`` の Hypernet学習の機能は削除されました。将来的に``train_network.py``に追加される予定です。 + +## その他の機能追加、バグ修正など + +### 学習スクリプトに共通: train_db.py, fine_tune.py and train_network.py + +- ``output_name``オプションを追加しました。保存されるモデルファイルの名前を指定できます。 + - ``--output_name style1``と指定すると、エポックごとに保存されるファイル名は``style1_000001.ckpt`` (または ``.safetensors``) に、最後に保存されるファイル名は``style1.ckpt``になります。 + - 省略時は今までと同じです(``epoch-000001.ckpt``および``last.ckpt``)。 +- ``save_last_n_epochs``オプションを追加しました。最新の n ファイル、stateだけ保存し、古いものは削除します。(shirayu氏に感謝します。) + - たとえば``--save_every_n_epochs=2 --save_last_n_epochs=3``と指定した時、8エポック目の終了時には、``epoch-000008.ckpt``が保存され``epoch-000002.ckpt``が削除されます。 + +### train_db.py + +- ``max_token_length``オプションを追加しました。75文字を超えるキャプションが使えるようになります。 + +### fine_tune.py + +- .npzファイルがなくても動作するようになりました。.npzファイルがない場合、VAEからlatentsを取得して動作します。 + - ``prepare_buckets_latents.py``を前処理で実行しなくても良くなります。ただし事前取得をしておいたほうが、2エポック以上学習する場合にはトータルで高速です。 + - この場合、解像度を指定するために``--resolution``オプションが必要です。 +- ``cache_latents``と``color_aug``オプションを追加しました。 + +### train_network.py + +- ``--gradient_checkpointing``がU-NetとText Encoderにも有効になりました。 + - メモリ消費が減ります。バッチサイズを大きくできますが、トータルでの学習時間は長くなるかもしれません。 + - dimension=4のLoRAはバッチサイズ1で6GB VRAMで学習できるかもしれません。 + +ドキュメントは未更新ですが少しずつ更新の予定です。 From c4bc435bc442d7222d72979e0b9b0366272d6685 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 9 Jan 2023 15:00:20 +0900 Subject: [PATCH 15/26] Update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f402187..a7829591 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ This repository contains training, generation and utility scripts for Stable Diffusion. 
-**January 9, 2023: Information about the update can be found at the end of the page.** +**January 9, 2023: Information about the update can be found at [the end of the page](#updates-jan-9-2023).** -**20231/1/9: 更新情報がページ末尾にありますのでご覧ください。** +**20231/1/9: 更新情報が[ページ末尾](#更新情報-202319)にありますのでご覧ください。** [日本語版README](./README-ja.md) From d8da85b38bcc70cd4d9381a10586b1e1f580493e Mon Sep 17 00:00:00 2001 From: Gaetano Bonofiglio Date: Mon, 9 Jan 2023 11:40:00 +0100 Subject: [PATCH 16/26] fix file not found when `[` is in the filename --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index bad954c2..fcd9880d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -715,7 +715,7 @@ def debug_dataset(train_dataset): def glob_images(dir, base): img_paths = [] for ext in IMAGE_EXTENSIONS: - img_paths.extend(glob.glob(os.path.join(dir, base + ext))) + img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext)))) return img_paths # endregion From 673f9ced4792607724669fd7edefcc11f3c6cddb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Jan 2023 21:06:58 +0900 Subject: [PATCH 17/26] Fix '*' is not working for DreamBooth --- library/train_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index fcd9880d..7a0f794b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -715,7 +715,10 @@ def debug_dataset(train_dataset): def glob_images(dir, base): img_paths = [] for ext in IMAGE_EXTENSIONS: - img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext)))) + if base == '*': + img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext))) + else: + img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext)))) return img_paths # endregion From f981dfd38ae857b25608d2baefabdce28b92b754 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 10 Jan 2023 17:43:35 +0900 Subject: [PATCH 18/26] Add credits --- README-ja.md | 6 +++++- README.md | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/README-ja.md b/README-ja.md index 4391b1db..95327b36 100644 --- a/README-ja.md +++ b/README-ja.md @@ -114,9 +114,13 @@ pip install --upgrade -r コマンドが成功すれば新しいバージョンが使用できます。 +## 謝意 + +LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。 + ## ライセンス -スクリプトのライセンスはASL 2.0ですが、一部他のライセンスのコードを含みます。 +スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。 [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT diff --git a/README.md b/README.md index a7829591..2495a129 100644 --- a/README.md +++ b/README.md @@ -99,9 +99,13 @@ pip install --upgrade -r requirements.txt Once the commands have completed successfully you should be ready to use the new version. +## Credits + +The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for great work!!! 
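The two ``glob_images`` fixes in the patches above (escaping ``[`` in paths, then restoring ``*`` as a wildcard base) come down to escaping only the literal parts of the glob pattern. A standalone sketch of that behaviour, assuming an illustrative function name (the extension list matches the one added to ``library/train_util.py``):

```python
import glob
import os

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]


def find_images(dir, base):
    # glob treats "[", "]", "*" and "?" as pattern syntax, so literal path parts
    # containing "[" must be escaped, while a "*" base must stay unescaped so it
    # still matches every file name in the directory.
    img_paths = []
    for ext in IMAGE_EXTENSIONS:
        if base == "*":
            pattern = os.path.join(glob.escape(dir), base + ext)  # escape the directory only
        else:
            pattern = glob.escape(os.path.join(dir, base + ext))  # escape the whole literal path
        img_paths.extend(glob.glob(pattern))
    return img_paths
```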
+ ## License -The majority of scripts is licensed under ASL 2.0 (including codes from Diffusers), however portions of the project are available under separate license terms: +The majority of scripts is licensed under ASL 2.0 (including codes from Diffusers, cloneofsimo's), however portions of the project are available under separate license terms: [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT From 2e4ce0fdff11843d2437b9b3bb4eb3756e8e6139 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 10 Jan 2023 02:49:52 -0800 Subject: [PATCH 19/26] Add training metadata to output LoRA model --- library/train_util.py | 14 +++++++++ networks/extract_lora_from_models.py | 2 +- networks/lora.py | 16 ++++++++-- train_network.py | 45 ++++++++++++++++++++++++++-- 4 files changed, 72 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 7a0f794b..70af44c9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -747,6 +747,20 @@ def exists(val): def default(val, d): return val if exists(val) else d + +def model_hash(filename): + try: + with open(filename, "rb") as file: + import hashlib + m = hashlib.sha256() + + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return 'NOFILE' + + # flash attention forwards and backwards # https://arxiv.org/abs/2205.14135 diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index c882e88f..0a4c3a00 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -135,7 +135,7 @@ def svd(args): if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) - lora_network_o.save_weights(args.save_to, save_dtype) + lora_network_o.save_weights(args.save_to, save_dtype, {}) print(f"LoRA weights are saved to: {args.save_to}") diff --git a/networks/lora.py b/networks/lora.py index 730a6376..dbef2aa1 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -6,6 +6,8 @@ import math import os import torch +import zipfile +import json class LoRAModule(torch.nn.Module): @@ -61,6 +63,7 @@ class LoRANetwork(torch.nn.Module): super().__init__() self.multiplier = multiplier self.lora_dim = lora_dim + self.metadata = {} # create module instances def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: @@ -91,11 +94,17 @@ class LoRANetwork(torch.nn.Module): names.add(lora.lora_name) def load_weights(self, file): + self.metadata = {} if os.path.splitext(file)[1] == '.safetensors': from safetensors.torch import load_file self.weights_sd = load_file(file) + self.metadata = self.weights_sd.metadata() else: self.weights_sd = torch.load(file, map_location='cpu') + with zipfile.ZipFile(file, "w") as zipf: + if "sd_scripts_metadata.json" in zipf.namelist(): + with zipf.open("sd_scripts_metadata.json", "r") as jsfile: + self.metadata = json.load(jsfile) def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): if self.weights_sd: @@ -174,7 +183,8 @@ class LoRANetwork(torch.nn.Module): def get_trainable_params(self): return self.parameters() - def save_weights(self, file, dtype): + def save_weights(self, file, dtype, metadata): + self.metadata = metadata state_dict = self.state_dict() if dtype is not None: @@ -185,6 +195,8 @@ class LoRANetwork(torch.nn.Module): if os.path.splitext(file)[1] == '.safetensors': from safetensors.torch 
import save_file - save_file(state_dict, file) + save_file(state_dict, file, metadata) else: torch.save(state_dict, file) + with zipfile.ZipFile(file, "w") as zipf: + zipf.writestr("sd_scripts_metadata.json", json.dumps(metadata)) diff --git a/train_network.py b/train_network.py index 9f292b97..3f45bae0 100644 --- a/train_network.py +++ b/train_network.py @@ -197,6 +197,47 @@ def train(args): print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + metadata = { + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset.num_train_images, + "ss_num_reg_images": train_dataset.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": 4 if args.network_dim is None else args.network_dim, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_resolution": args.resolution, + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_cache_latents": bool(args.cache_latents), + "ss_enable_bucket": bool(args.enable_bucket), + "ss_min_bucket_reso": args.min_bucket_reso, + "ss_max_bucket_reso": args.max_bucket_reso, + "ss_seed": args.seed + } + + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + metadata = {k: str(v) for k, v in metadata.items()} + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 @@ -296,7 +337,7 @@ def train(args): ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"saving checkpoint: {ckpt_file}") - unwrap_model(network).save_weights(ckpt_file, save_dtype) + unwrap_model(network).save_weights(ckpt_file, save_dtype, metadata) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' 
+ args.save_model_as @@ -330,7 +371,7 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"save trained model to {ckpt_file}") - network.save_weights(ckpt_file, save_dtype) + network.save_weights(ckpt_file, save_dtype, metadata) print("model saved.") From 0c4423d9dc65465cf1a65762aca6c8b642a52759 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 10 Jan 2023 02:50:04 -0800 Subject: [PATCH 20/26] Add epoch number to metadata --- train_network.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 3f45bae0..c5593c46 100644 --- a/train_network.py +++ b/train_network.py @@ -194,7 +194,7 @@ def train(args): print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") metadata = { @@ -249,6 +249,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") + metadata["ss_epoch"] = str(epoch+1) network.on_epoch_start(text_encoder, unet) @@ -352,6 +353,8 @@ def train(args): # end of epoch + metadata["ss_epoch"] = str(num_train_epochs) + is_main_process = accelerator.is_main_process if is_main_process: network = unwrap_model(network) From de37fd9906dd5d5d2ad725580ede018c89faf117 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 10 Jan 2023 02:55:25 -0800 Subject: [PATCH 21/26] Fix metadata loading --- networks/lora.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index dbef2aa1..98e8e4a4 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -58,12 +58,12 @@ class LoRANetwork(torch.nn.Module): TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' + METADATA_FILENAME = "sd_scripts_metadata.json" def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: super().__init__() self.multiplier = multiplier self.lora_dim = lora_dim - self.metadata = {} # create module instances def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: @@ -94,17 +94,11 @@ class LoRANetwork(torch.nn.Module): names.add(lora.lora_name) def load_weights(self, file): - self.metadata = {} if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file + from safetensors.torch import load_file, safe_open self.weights_sd = load_file(file) - self.metadata = self.weights_sd.metadata() else: self.weights_sd = torch.load(file, map_location='cpu') - with zipfile.ZipFile(file, "w") as zipf: - if "sd_scripts_metadata.json" in zipf.namelist(): - with zipf.open("sd_scripts_metadata.json", "r") as jsfile: - self.metadata = json.load(jsfile) def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): if self.weights_sd: @@ -184,7 +178,6 @@ class LoRANetwork(torch.nn.Module): return self.parameters() def save_weights(self, file, dtype, metadata): - self.metadata = metadata state_dict = self.state_dict() if dtype is not None: @@ -199,4 
+192,4 @@ class LoRANetwork(torch.nn.Module): else: torch.save(state_dict, file) with zipfile.ZipFile(file, "w") as zipf: - zipf.writestr("sd_scripts_metadata.json", json.dumps(metadata)) + zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata)) From e4f9b2b71504892bcce6191274ce5666ad87e7db Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Jan 2023 23:12:18 +0900 Subject: [PATCH 22/26] Add VAE to meatada, add no_metadata option --- library/train_util.py | 18 +++++------ networks/lora.py | 8 +++-- train_network.py | 75 +++++++++++++++++++++++++------------------ 3 files changed, 59 insertions(+), 42 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 70af44c9..ade66a38 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -749,16 +749,16 @@ def default(val, d): def model_hash(filename): - try: - with open(filename, "rb") as file: - import hashlib - m = hashlib.sha256() + try: + with open(filename, "rb") as file: + import hashlib + m = hashlib.sha256() - file.seek(0x100000) - m.update(file.read(0x10000)) - return m.hexdigest()[0:8] - except FileNotFoundError: - return 'NOFILE' + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return 'NOFILE' # flash attention forwards and backwards diff --git a/networks/lora.py b/networks/lora.py index 98e8e4a4..77fe26a7 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -178,6 +178,9 @@ class LoRANetwork(torch.nn.Module): return self.parameters() def save_weights(self, file, dtype, metadata): + if len(metadata) == 0: + metadata = None + state_dict = self.state_dict() if dtype is not None: @@ -191,5 +194,6 @@ class LoRANetwork(torch.nn.Module): save_file(state_dict, file, metadata) else: torch.save(state_dict, file) - with zipfile.ZipFile(file, "w") as zipf: - zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata)) + if metadata is not None: + with zipfile.ZipFile(file, "w") as zipf: + zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata)) diff --git a/train_network.py b/train_network.py index c5593c46..c920c5ed 100644 --- a/train_network.py +++ b/train_network.py @@ -198,37 +198,42 @@ def train(args): print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") metadata = { - "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, - "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset.num_train_images, - "ss_num_reg_images": train_dataset.num_reg_images, - "ss_num_batches_per_epoch": len(train_dataloader), - "ss_num_epochs": num_train_epochs, - "ss_batch_size_per_device": args.train_batch_size, - "ss_total_batch_size": total_batch_size, - "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, - "ss_max_train_steps": args.max_train_steps, - "ss_lr_warmup_steps": args.lr_warmup_steps, - "ss_lr_scheduler": args.lr_scheduler, - "ss_network_module": args.network_module, - "ss_network_dim": 4 if args.network_dim is None else args.network_dim, - "ss_full_fp16": bool(args.full_fp16), - "ss_v2": bool(args.v2), - "ss_resolution": args.resolution, - "ss_clip_skip": args.clip_skip, - "ss_max_token_length": args.max_token_length, - "ss_color_aug": bool(args.color_aug), - "ss_flip_aug": bool(args.flip_aug), - "ss_random_crop": bool(args.random_crop), - "ss_shuffle_caption": bool(args.shuffle_caption), - "ss_cache_latents": bool(args.cache_latents), - "ss_enable_bucket": bool(args.enable_bucket), - "ss_min_bucket_reso": args.min_bucket_reso, - "ss_max_bucket_reso": 
args.max_bucket_reso, - "ss_seed": args.seed + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data + "ss_num_reg_images": train_dataset.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_mixed_precision": args.mixed_precision, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_resolution": args.resolution, + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_cache_latents": bool(args.cache_latents), + "ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT + "ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset + "ss_max_bucket_reso": args.max_bucket_reso, + "ss_seed": args.seed } + # uncomment if another network is added + # for key, value in net_kwargs.items(): + # metadata["ss_arg_" + key] = value + if args.pretrained_model_name_or_path is not None: sd_model_name = args.pretrained_model_name_or_path if os.path.exists(sd_model_name): @@ -236,6 +241,13 @@ def train(args): sd_model_name = os.path.basename(sd_model_name) metadata["ss_sd_model_name"] = sd_model_name + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + metadata = {k: str(v) for k, v in metadata.items()} progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") @@ -338,7 +350,7 @@ def train(args): ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"saving checkpoint: {ckpt_file}") - unwrap_model(network).save_weights(ckpt_file, save_dtype, metadata) + unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' 
+ args.save_model_as @@ -374,7 +386,7 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"save trained model to {ckpt_file}") - network.save_weights(ckpt_file, save_dtype, metadata) + network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata) print("model saved.") @@ -385,6 +397,7 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, True, True) train_util.add_training_arguments(parser, True) + parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") From 9622082eb8183730bee17d21df9d4f20422979ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Jan 2023 23:12:35 +0900 Subject: [PATCH 23/26] Print metadata for additional network --- gen_img_diffusers.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 208b1b70..1912e720 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -46,11 +46,13 @@ VGG( ) """ +import json from typing import List, Optional, Union import glob import importlib import inspect import time +import zipfile from diffusers.utils import deprecate from diffusers.configuration_utils import FrozenDict import argparse @@ -1972,6 +1974,19 @@ def main(args): if args.network_weights and i < len(args.network_weights): network_weight = args.network_weights[i] print("load network weights from:", network_weight) + + metadata = None + if os.path.splitext(network_weight)[1] == '.safetensors': + from safetensors.torch import safe_open + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + else: + with zipfile.ZipFile(network_weight, "r") as zipf: + if "sd_scripts_metadata.json" in zipf.namelist(): + with zipf.open("sd_scripts_metadata.json", "r") as jsfile: + metadata = json.load(jsfile) + print(f"metadata for: {network_weight}: {metadata}") + network.load_weights(network_weight) network.apply_to(text_encoder, unet) From 9fd91d26a34436c1f906439f9b9fc3be1790a72b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 12 Jan 2023 10:54:21 +0900 Subject: [PATCH 24/26] Store metadata to .ckpt as value of state dict --- networks/lora.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 77fe26a7..de87d064 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -58,7 +58,7 @@ class LoRANetwork(torch.nn.Module): TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' - METADATA_FILENAME = "sd_scripts_metadata.json" + METADATA_KEY_NAME = "sd_scripts_metadata" def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: super().__init__() @@ -178,7 +178,7 @@ class LoRANetwork(torch.nn.Module): return self.parameters() def save_weights(self, file, dtype, metadata): - if len(metadata) == 0: + if metadata is not None and len(metadata) == 0: metadata = None state_dict = self.state_dict() @@ -193,7 +193,6 @@ class LoRANetwork(torch.nn.Module): from safetensors.torch import save_file save_file(state_dict, file, metadata) else: - torch.save(state_dict, file) if metadata is not None: - with zipfile.ZipFile(file, "w") as zipf: - zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata)) + state_dict[LoRANetwork.METADATA_KEY_NAME] 
= metadata
+      torch.save(state_dict, file)

From eba142ccb2da4a5c96df5489ae3441a97333a461 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 12 Jan 2023 21:52:55 +0900
Subject: [PATCH 25/26] do not save metadata in .pt/.ckpt

---
 gen_img_diffusers.py | 9 ++-------
 networks/lora.py     | 5 -----
 2 files changed, 2 insertions(+), 12 deletions(-)

diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 1912e720..f5133407 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -1975,17 +1975,12 @@ def main(args):
         network_weight = args.network_weights[i]
         print("load network weights from:", network_weight)
 
-        metadata = None
         if os.path.splitext(network_weight)[1] == '.safetensors':
           from safetensors.torch import safe_open
           with safe_open(network_weight, framework="pt") as f:
             metadata = f.metadata()
-        else:
-          with zipfile.ZipFile(network_weight, "r") as zipf:
-            if "sd_scripts_metadata.json" in zipf.namelist():
-              with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
-                metadata = json.load(jsfile)
-        print(f"metadata for: {network_weight}: {metadata}")
+          if metadata is not None:
+            print(f"metadata for: {network_weight}: {metadata}")
 
         network.load_weights(network_weight)
 
diff --git a/networks/lora.py b/networks/lora.py
index de87d064..3f8244e0 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -6,8 +6,6 @@ import math
 import os
 
 import torch
-import zipfile
-import json
 
 
 class LoRAModule(torch.nn.Module):
@@ -58,7 +56,6 @@ class LoRANetwork(torch.nn.Module):
   TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
   LORA_PREFIX_UNET = 'lora_unet'
   LORA_PREFIX_TEXT_ENCODER = 'lora_te'
-  METADATA_KEY_NAME = "sd_scripts_metadata"
 
   def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
     super().__init__()
@@ -193,6 +190,4 @@ class LoRANetwork(torch.nn.Module):
       from safetensors.torch import save_file
       save_file(state_dict, file, metadata)
     else:
-      if metadata is not None:
-        state_dict[LoRANetwork.METADATA_KEY_NAME] = metadata
       torch.save(state_dict, file)

From bf691aef69d883e4d9e61104609b479ba3be9aad Mon Sep 17 00:00:00 2001
From: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Date: Thu, 12 Jan 2023 23:21:21 +0900
Subject: [PATCH 26/26] Update README.md

Add updates.
---
 README.md | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 2495a129..66611977 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,19 @@
 This repository contains training, generation and utility scripts for Stable Diffusion.
 
-**January 9, 2023: Information about the update can be found at [the end of the page](#updates-jan-9-2023).**
+## Updates
+
+- January 12, 2023, 2023/1/12
+  - Metadata is now saved in the model (.safetensors only): model name, VAE name, training steps, learning rate and so on. The metadata will be viewable with the sd-webui-additional-networks extension in the near future. If you do not want to save it, specify the ``no_metadata`` option.
+  - メタデータが保存されるようになりました( .safetensors 形式の場合のみ)(モデル名、VAE 名、ステップ数、学習率など)。近日中に拡張から確認できるようになる予定です。メタデータを保存したくない場合は ``no_metadata`` オプションを指定してください。
+
+**January 9, 2023: Important information about the update can be found at [the end of the page](#updates-jan-9-2023).**
 
 **20231/1/9: 更新情報が[ページ末尾](#更新情報-202319)にありますのでご覧ください。**
 
 [日本語版README](./README-ja.md)
 
+## 
+
 For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
 
 This repository contains the scripts for:
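Since the metadata ends up in the safetensors header as ``ss_*`` keys, it can be inspected without loading the weights, using the same ``safetensors`` API the scripts themselves call. A minimal sketch (the file name is only an example of an epoch checkpoint produced with ``--output_name style1`` and ``--save_model_as safetensors``):

```python
from safetensors.torch import safe_open

# Read the ss_* training metadata written by train_network.py (.safetensors files only).
with safe_open("style1_000008.safetensors", framework="pt") as f:
    metadata = f.metadata() or {}

for key in ("ss_sd_model_name", "ss_network_dim", "ss_learning_rate", "ss_clip_skip"):
    print(key, "=", metadata.get(key))
```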