Merge branch 'original-u-net' into dev

This commit is contained in:
Kohya S
2023-06-17 21:57:08 +09:00
committed by GitHub
18 changed files with 2951 additions and 915 deletions

View File

@@ -5,20 +5,37 @@ import re
from typing import List, Optional, Union
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
def prepare_scheduler_for_custom_training(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
return
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
noise_scheduler.all_snr = all_snr.to(device)
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
loss = loss * snr_weight
return loss
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
loss = loss * scale
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
@@ -29,6 +46,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
default=None,
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
)
parser.add_argument(
"--scale_v_pred_loss_like_noise_pred",
action="store_true",
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
@@ -243,11 +265,6 @@ def get_unweighted_text_embeddings(
text_embedding = enc_out["hidden_states"][-clip_skip]
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
@@ -262,7 +279,12 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings = text_encoder(text_input)[0]
if clip_skip is None or clip_skip == 1:
text_embeddings = text_encoder(text_input)[0]
else:
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-clip_skip]
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
return text_embeddings
@@ -434,46 +456,3 @@ def perlin_noise(noise, device, octaves):
noise += noise_perlin # broadcast for each batch
return noise / noise.std() # Scaled back to roughly unit variance
"""
def max_norm(state_dict, max_norm_value, device):
downkeys = []
upkeys = []
alphakeys = []
norms = []
keys_scaled = 0
for key in state_dict.keys():
if "lora_down" in key and "weight" in key:
downkeys.append(key)
upkeys.append(key.replace("lora_down", "lora_up"))
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
dim = down.shape[0]
scale = alpha / dim
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown *= scale
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)

View File

@@ -245,11 +245,6 @@ def get_unweighted_text_embeddings(
text_embedding = enc_out["hidden_states"][-clip_skip]
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
@@ -264,7 +259,12 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings = pipe.text_encoder(text_input)[0]
if clip_skip is None or clip_skip == 1:
text_embeddings = pipe.text_encoder(text_input)[0]
else:
enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-clip_skip]
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
return text_embeddings
@@ -517,6 +517,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
# clip_skip: int,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

View File

@@ -4,9 +4,11 @@
import math
import os
import torch
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
# DiffUsers版StableDiffusionのモデルパラメータ
NUM_TRAIN_TIMESTEPS = 1000
@@ -126,17 +128,30 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
if diffusers.__version__ < "0.17.0":
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
else:
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
@@ -191,8 +206,16 @@ def assign_to_checkpoint(
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
reshaping = False
if diffusers.__version__ < "0.17.0":
if "proj_attn.weight" in new_path:
reshaping = True
else:
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
reshaping = True
if reshaping:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
@@ -361,7 +384,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
# SDのv2では1*1のconv2dがlinearに変わっている
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
if v2 and not config.get('use_linear_projection', False):
if v2 and not config.get("use_linear_projection", False):
linear_transformer_to_conv(new_checkpoint)
return new_checkpoint
@@ -877,14 +900,24 @@ def convert_vae_state_dict(vae_state_dict):
sd_mid_res_prefix = f"mid.block_{i+1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "query."),
("k.", "key."),
("v.", "value."),
("proj_out.", "proj_attn."),
]
if diffusers.__version__ < "0.17.0":
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "query."),
("k.", "key."),
("v.", "value."),
("proj_out.", "proj_attn."),
]
else:
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "to_q."),
("k.", "to_k."),
("v.", "to_v."),
("proj_out.", "to_out.0."),
]
mapping = {k: k for k in vae_state_dict.keys()}
for k, v in mapping.items():
@@ -901,7 +934,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
# print(f"Reshaping {k} for SD format")
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
@@ -998,10 +1031,31 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
else:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
logging.set_verbosity_error() # don't show annoying warning
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
logging.set_verbosity_warning()
# logging.set_verbosity_error() # don't show annoying warning
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
# logging.set_verbosity_warning()
# print(f"config: {text_model.config}")
cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
torch_dtype="float32",
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
print("loading text encoder:", info)

1593
library/original_unet.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -36,7 +36,6 @@ from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer
import transformers
import diffusers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
StableDiffusionPipeline,
@@ -52,6 +51,7 @@ from diffusers import (
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
)
from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download
import albumentations as albu
import numpy as np
@@ -65,6 +65,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
from library.attention_processors import FlashAttnProcessor
from library.hypernetwork import replace_attentions_for_hypernetwork
from library.original_unet import UNet2DConditionModel
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -1828,6 +1829,76 @@ def glob_images_pathlib(dir_path, recursive):
return image_paths
class MinimalDataset(BaseDataset):
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
self.num_train_images = 0 # update in subclass
self.num_reg_images = 0 # update in subclass
self.datasets = [self]
self.batch_size = 1 # update in subclass
self.subsets = [self]
self.num_repeats = 1 # update in subclass if needed
self.img_count = 1 # update in subclass if needed
self.bucket_info = {}
self.is_reg = False
self.image_dir = "dummy" # for metadata
def is_latent_cacheable(self) -> bool:
return False
def __len__(self):
raise NotImplementedError
# override to avoid shuffling buckets
def set_current_epoch(self, epoch):
self.current_epoch = epoch
def __getitem__(self, idx):
r"""
The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.
Returns: example like this:
for i in range(batch_size):
image_key = ... # whatever hashable
image_keys.append(image_key)
image = ... # PIL Image
img_tensor = self.image_transforms(img)
images.append(img_tensor)
caption = ... # str
input_ids = self.get_input_ids(caption)
input_ids_list.append(input_ids)
captions.append(caption)
images = torch.stack(images, dim=0)
input_ids_list = torch.stack(input_ids_list, dim=0)
example = {
"images": images,
"input_ids": input_ids_list,
"captions": captions, # for debug_dataset
"latents": None,
"image_keys": image_keys, # for debug_dataset
"loss_weights": torch.ones(batch_size, dtype=torch.float32),
}
return example
"""
raise NotImplementedError
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
module = ".".join(args.dataset_class.split(".")[:-1])
dataset_class = args.dataset_class.split(".")[-1]
module = importlib.import_module(module)
dataset_class = getattr(module, dataset_class)
train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
return train_dataset_group
# endregion
# region モジュール入れ替え部
@@ -1941,59 +2012,73 @@ def get_git_revision_hash() -> str:
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
replace_attentions_for_hypernetwork()
# unet is not used currently, but it is here for future use
unet.enable_xformers_memory_efficient_attention()
return
# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
# replace_attentions_for_hypernetwork()
# # unet is not used currently, but it is here for future use
# unet.enable_xformers_memory_efficient_attention()
# return
# if mem_eff_attn:
# unet.set_attn_processor(FlashAttnProcessor())
# elif xformers:
# unet.enable_xformers_memory_efficient_attention()
# def replace_unet_cross_attn_to_xformers():
# print("CrossAttention.forward has been replaced to enable xformers.")
# try:
# import xformers.ops
# except ImportError:
# raise ImportError("No xformers / xformersがインストールされていないようです")
# def forward_xformers(self, x, context=None, mask=None):
# h = self.heads
# q_in = self.to_q(x)
# context = default(context, x)
# context = context.to(x.dtype)
# if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
# context_k, context_v = self.hypernetwork.forward(x, context)
# context_k = context_k.to(x.dtype)
# context_v = context_v.to(x.dtype)
# else:
# context_k = context
# context_v = context
# k_in = self.to_k(context_k)
# v_in = self.to_v(context_v)
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
# del q_in, k_in, v_in
# q = q.contiguous()
# k = k.contiguous()
# v = v.contiguous()
# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
# out = rearrange(out, "b n h d -> b n (h d)", h=h)
# # diffusers 0.7.0~
# out = self.to_out[0](out)
# out = self.to_out[1](out)
# return out
# diffusers.models.attention.CrossAttention.forward = forward_xformers
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
unet.set_attn_processor(FlashAttnProcessor())
print("Enable memory efficient attention for U-Net")
unet.set_use_memory_efficient_attention(False, True)
elif xformers:
unet.enable_xformers_memory_efficient_attention()
def replace_unet_cross_attn_to_xformers():
print("CrossAttention.forward has been replaced to enable xformers.")
try:
import xformers.ops
except ImportError:
raise ImportError("No xformers / xformersがインストールされていないようです")
def forward_xformers(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
context = context.to(x.dtype)
if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
context_k, context_v = self.hypernetwork.forward(x, context)
context_k = context_k.to(x.dtype)
context_v = context_v.to(x.dtype)
else:
context_k = context
context_v = context
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
out = rearrange(out, "b n h d -> b n (h d)", h=h)
# diffusers 0.7.0~
out = self.to_out[0](out)
out = self.to_out[1](out)
return out
diffusers.models.attention.CrossAttention.forward = forward_xformers
print("Enable xformers for U-Net")
try:
import xformers.ops
except ImportError:
raise ImportError("No xformers / xformersがインストールされていないようです")
unet.set_use_memory_efficient_attention(True, False)
elif sdpa:
print("Enable SDPA for U-Net")
unet.set_use_sdpa(True)
"""
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
@@ -2242,6 +2327,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使うPyTorch 2.0が必要)")
parser.add_argument(
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
)
@@ -2428,6 +2514,11 @@ def verify_training_args(args: argparse.Namespace):
if args.adaptive_noise_scale is not None and args.noise_offset is None:
raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
raise ValueError(
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
)
def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
@@ -2506,7 +2597,6 @@ def add_dataset_arguments(
default=1,
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
)
parser.add_argument(
"--token_warmup_step",
type=float,
@@ -2514,6 +2604,13 @@ def add_dataset_arguments(
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大",
)
parser.add_argument(
"--dataset_class",
type=str,
default=None,
help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
)
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
@@ -2788,15 +2885,7 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.SGD
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
elif optimizer_type.startswith("DAdapt".lower()):
# DAdaptation family
# check dadaptation is installed
try:
import dadaptation
import dadaptation.experimental as experimental
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
# check lr and lr_count, and print warning
actual_lr = lr
lr_count = 1
@@ -2809,40 +2898,60 @@ def get_optimizer(args, trainable_params):
if actual_lr <= 0.1:
print(
f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
)
print("recommend option: lr=1.0 / 推奨は1.0です")
if lr_count > 1:
print(
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合Text EncoderとU-Netなど、最初の学習率のみが有効になります: lr={actual_lr}"
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合Text EncoderとU-Netなど、最初の学習率のみが有効になります: lr={actual_lr}"
)
# set optimizer
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
optimizer_class = experimental.DAdaptAdamPreprint
print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdaGrad".lower():
optimizer_class = dadaptation.DAdaptAdaGrad
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdam".lower():
optimizer_class = dadaptation.DAdaptAdam
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdan".lower():
optimizer_class = dadaptation.DAdaptAdan
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdanIP".lower():
optimizer_class = experimental.DAdaptAdanIP
print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptLion".lower():
optimizer_class = dadaptation.DAdaptLion
print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptSGD".lower():
optimizer_class = dadaptation.DAdaptSGD
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
if optimizer_type.startswith("DAdapt".lower()):
# DAdaptation family
# check dadaptation is installed
try:
import dadaptation
import dadaptation.experimental as experimental
except ImportError:
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
# set optimizer
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
optimizer_class = experimental.DAdaptAdamPreprint
print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdaGrad".lower():
optimizer_class = dadaptation.DAdaptAdaGrad
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdam".lower():
optimizer_class = dadaptation.DAdaptAdam
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdan".lower():
optimizer_class = dadaptation.DAdaptAdan
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptAdanIP".lower():
optimizer_class = experimental.DAdaptAdanIP
print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptLion".lower():
optimizer_class = dadaptation.DAdaptLion
print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
elif optimizer_type == "DAdaptSGD".lower():
optimizer_class = dadaptation.DAdaptSGD
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
else:
# Prodigy
# check Prodigy is installed
try:
import prodigyopt
except ImportError:
raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
print(f"use Prodigy optimizer | {optimizer_kwargs}")
optimizer_class = prodigyopt.Prodigy
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "Adafactor".lower():
# 引数を確認して適宜補正する
@@ -3093,23 +3202,9 @@ def prepare_accelerator(args: argparse.Namespace):
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=log_with,
logging_dir=logging_dir,
project_dir=logging_dir,
)
# accelerateの互換性問題を解決する
accelerator_0_15 = True
try:
accelerator.unwrap_model("dummy", True)
print("Using accelerator 0.15.0 or above.")
except TypeError:
accelerator_0_15 = False
def unwrap_model(model):
if accelerator_0_15:
return accelerator.unwrap_model(model, True)
return accelerator.unwrap_model(model)
return accelerator, unwrap_model
return accelerator
def prepare_dtype(args: argparse.Namespace):
@@ -3146,11 +3241,26 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
print(
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
)
raise ex
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet
del pipe
# Diffusers U-Net to original U-Net
# TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう
# print(f"unet config: {unet.config}")
original_unet = UNet2DConditionModel(
unet.config.sample_size,
unet.config.attention_head_dim,
unet.config.cross_attention_dim,
unet.config.use_linear_projection,
unet.config.upcast_attention,
)
original_unet.load_state_dict(unet.state_dict())
unet = original_unet
print("U-Net converted to original U-Net")
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, weight_dtype)
@@ -3580,6 +3690,7 @@ def sample_images(
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
pipeline.to(device)
save_dir = args.output_dir + "/sample"
@@ -3769,4 +3880,4 @@ class collater_class:
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
return examples[0]