mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'original-u-net' into dev
This commit is contained in:
@@ -5,20 +5,37 @@ import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
if hasattr(noise_scheduler, "all_snr"):
|
||||
return
|
||||
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
||||
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
||||
alpha = sqrt_alphas_cumprod
|
||||
sigma = sqrt_one_minus_alphas_cumprod
|
||||
all_snr = (alpha / sigma) ** 2
|
||||
snr = torch.stack([all_snr[t] for t in timesteps])
|
||||
|
||||
noise_scheduler.all_snr = all_snr.to(device)
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
|
||||
loss = loss * snr_weight
|
||||
return loss
|
||||
|
||||
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
scale = snr_t / (snr_t + 1)
|
||||
|
||||
loss = loss * scale
|
||||
return loss
|
||||
|
||||
|
||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||
|
||||
|
||||
@@ -29,6 +46,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
default=None,
|
||||
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_v_pred_loss_like_noise_pred",
|
||||
action="store_true",
|
||||
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
||||
)
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
"--weighted_captions",
|
||||
@@ -243,11 +265,6 @@ def get_unweighted_text_embeddings(
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
# discard the ending token
|
||||
@@ -262,7 +279,12 @@ def get_unweighted_text_embeddings(
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings = text_encoder(text_input)[0]
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embeddings = text_encoder(text_input)[0]
|
||||
else:
|
||||
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||
return text_embeddings
|
||||
|
||||
|
||||
@@ -434,46 +456,3 @@ def perlin_noise(noise, device, octaves):
|
||||
noise += noise_perlin # broadcast for each batch
|
||||
return noise / noise.std() # Scaled back to roughly unit variance
|
||||
"""
|
||||
|
||||
|
||||
def max_norm(state_dict, max_norm_value, device):
|
||||
downkeys = []
|
||||
upkeys = []
|
||||
alphakeys = []
|
||||
norms = []
|
||||
keys_scaled = 0
|
||||
|
||||
for key in state_dict.keys():
|
||||
if "lora_down" in key and "weight" in key:
|
||||
downkeys.append(key)
|
||||
upkeys.append(key.replace("lora_down", "lora_up"))
|
||||
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
||||
|
||||
for i in range(len(downkeys)):
|
||||
down = state_dict[downkeys[i]].to(device)
|
||||
up = state_dict[upkeys[i]].to(device)
|
||||
alpha = state_dict[alphakeys[i]].to(device)
|
||||
dim = down.shape[0]
|
||||
scale = alpha / dim
|
||||
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||
else:
|
||||
updown = up @ down
|
||||
|
||||
updown *= scale
|
||||
|
||||
norm = updown.norm().clamp(min=max_norm_value / 2)
|
||||
desired = torch.clamp(norm, max=max_norm_value)
|
||||
ratio = desired.cpu() / norm.cpu()
|
||||
sqrt_ratio = ratio**0.5
|
||||
if ratio != 1:
|
||||
keys_scaled += 1
|
||||
state_dict[upkeys[i]] *= sqrt_ratio
|
||||
state_dict[downkeys[i]] *= sqrt_ratio
|
||||
scalednorm = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
|
||||
@@ -245,11 +245,6 @@ def get_unweighted_text_embeddings(
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
# discard the ending token
|
||||
@@ -264,7 +259,12 @@ def get_unweighted_text_embeddings(
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||
else:
|
||||
enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||
return text_embeddings
|
||||
|
||||
|
||||
@@ -517,6 +517,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
# clip_skip: int,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
|
||||
@@ -4,9 +4,11 @@
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
import diffusers
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||
NUM_TRAIN_TIMESTEPS = 1000
|
||||
@@ -126,17 +128,30 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
if diffusers.__version__ < "0.17.0":
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
else:
|
||||
new_item = new_item.replace("q.weight", "to_q.weight")
|
||||
new_item = new_item.replace("q.bias", "to_q.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "to_k.weight")
|
||||
new_item = new_item.replace("k.bias", "to_k.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "to_v.weight")
|
||||
new_item = new_item.replace("v.bias", "to_v.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
||||
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
@@ -191,8 +206,16 @@ def assign_to_checkpoint(
|
||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
if "proj_attn.weight" in new_path:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||
reshaping = False
|
||||
if diffusers.__version__ < "0.17.0":
|
||||
if "proj_attn.weight" in new_path:
|
||||
reshaping = True
|
||||
else:
|
||||
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
|
||||
reshaping = True
|
||||
|
||||
if reshaping:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
||||
else:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||
|
||||
@@ -361,7 +384,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
|
||||
# SDのv2では1*1のconv2dがlinearに変わっている
|
||||
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
|
||||
if v2 and not config.get('use_linear_projection', False):
|
||||
if v2 and not config.get("use_linear_projection", False):
|
||||
linear_transformer_to_conv(new_checkpoint)
|
||||
|
||||
return new_checkpoint
|
||||
@@ -877,14 +900,24 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
("q.", "query."),
|
||||
("k.", "key."),
|
||||
("v.", "value."),
|
||||
("proj_out.", "proj_attn."),
|
||||
]
|
||||
if diffusers.__version__ < "0.17.0":
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
("q.", "query."),
|
||||
("k.", "key."),
|
||||
("v.", "value."),
|
||||
("proj_out.", "proj_attn."),
|
||||
]
|
||||
else:
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
("q.", "to_q."),
|
||||
("k.", "to_k."),
|
||||
("v.", "to_v."),
|
||||
("proj_out.", "to_out.0."),
|
||||
]
|
||||
|
||||
mapping = {k: k for k in vae_state_dict.keys()}
|
||||
for k, v in mapping.items():
|
||||
@@ -901,7 +934,7 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
# print(f"Reshaping {k} for SD format")
|
||||
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
|
||||
return new_state_dict
|
||||
@@ -998,10 +1031,31 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
||||
else:
|
||||
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
||||
|
||||
logging.set_verbosity_error() # don't show annoying warning
|
||||
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
# logging.set_verbosity_error() # don't show annoying warning
|
||||
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
||||
# logging.set_verbosity_warning()
|
||||
# print(f"config: {text_model.config}")
|
||||
cfg = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="quick_gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=768,
|
||||
torch_dtype="float32",
|
||||
)
|
||||
text_model = CLIPTextModel._from_config(cfg)
|
||||
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
||||
print("loading text encoder:", info)
|
||||
|
||||
|
||||
1593
library/original_unet.py
Normal file
1593
library/original_unet.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -36,7 +36,6 @@ from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer
|
||||
import transformers
|
||||
import diffusers
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
@@ -52,6 +51,7 @@ from diffusers import (
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
)
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import albumentations as albu
|
||||
import numpy as np
|
||||
@@ -65,6 +65,7 @@ import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
from library.attention_processors import FlashAttnProcessor
|
||||
from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
@@ -1828,6 +1829,76 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
return image_paths
|
||||
|
||||
|
||||
class MinimalDataset(BaseDataset):
|
||||
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
|
||||
self.num_train_images = 0 # update in subclass
|
||||
self.num_reg_images = 0 # update in subclass
|
||||
self.datasets = [self]
|
||||
self.batch_size = 1 # update in subclass
|
||||
|
||||
self.subsets = [self]
|
||||
self.num_repeats = 1 # update in subclass if needed
|
||||
self.img_count = 1 # update in subclass if needed
|
||||
self.bucket_info = {}
|
||||
self.is_reg = False
|
||||
self.image_dir = "dummy" # for metadata
|
||||
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return False
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# override to avoid shuffling buckets
|
||||
def set_current_epoch(self, epoch):
|
||||
self.current_epoch = epoch
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""
|
||||
The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.
|
||||
|
||||
Returns: example like this:
|
||||
|
||||
for i in range(batch_size):
|
||||
image_key = ... # whatever hashable
|
||||
image_keys.append(image_key)
|
||||
|
||||
image = ... # PIL Image
|
||||
img_tensor = self.image_transforms(img)
|
||||
images.append(img_tensor)
|
||||
|
||||
caption = ... # str
|
||||
input_ids = self.get_input_ids(caption)
|
||||
input_ids_list.append(input_ids)
|
||||
|
||||
captions.append(caption)
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
input_ids_list = torch.stack(input_ids_list, dim=0)
|
||||
example = {
|
||||
"images": images,
|
||||
"input_ids": input_ids_list,
|
||||
"captions": captions, # for debug_dataset
|
||||
"latents": None,
|
||||
"image_keys": image_keys, # for debug_dataset
|
||||
"loss_weights": torch.ones(batch_size, dtype=torch.float32),
|
||||
}
|
||||
return example
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
|
||||
module = ".".join(args.dataset_class.split(".")[:-1])
|
||||
dataset_class = args.dataset_class.split(".")[-1]
|
||||
module = importlib.import_module(module)
|
||||
dataset_class = getattr(module, dataset_class)
|
||||
train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
|
||||
return train_dataset_group
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region モジュール入れ替え部
|
||||
@@ -1941,59 +2012,73 @@ def get_git_revision_hash() -> str:
|
||||
|
||||
|
||||
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
replace_attentions_for_hypernetwork()
|
||||
# unet is not used currently, but it is here for future use
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
return
|
||||
# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
# replace_attentions_for_hypernetwork()
|
||||
# # unet is not used currently, but it is here for future use
|
||||
# unet.enable_xformers_memory_efficient_attention()
|
||||
# return
|
||||
# if mem_eff_attn:
|
||||
# unet.set_attn_processor(FlashAttnProcessor())
|
||||
# elif xformers:
|
||||
# unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
|
||||
# def replace_unet_cross_attn_to_xformers():
|
||||
# print("CrossAttention.forward has been replaced to enable xformers.")
|
||||
# try:
|
||||
# import xformers.ops
|
||||
# except ImportError:
|
||||
# raise ImportError("No xformers / xformersがインストールされていないようです")
|
||||
|
||||
# def forward_xformers(self, x, context=None, mask=None):
|
||||
# h = self.heads
|
||||
# q_in = self.to_q(x)
|
||||
|
||||
# context = default(context, x)
|
||||
# context = context.to(x.dtype)
|
||||
|
||||
# if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
|
||||
# context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
# context_k = context_k.to(x.dtype)
|
||||
# context_v = context_v.to(x.dtype)
|
||||
# else:
|
||||
# context_k = context
|
||||
# context_v = context
|
||||
|
||||
# k_in = self.to_k(context_k)
|
||||
# v_in = self.to_v(context_v)
|
||||
|
||||
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
||||
# del q_in, k_in, v_in
|
||||
|
||||
# q = q.contiguous()
|
||||
# k = k.contiguous()
|
||||
# v = v.contiguous()
|
||||
# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
|
||||
# out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
||||
|
||||
# # diffusers 0.7.0~
|
||||
# out = self.to_out[0](out)
|
||||
# out = self.to_out[1](out)
|
||||
# return out
|
||||
|
||||
# diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
if mem_eff_attn:
|
||||
unet.set_attn_processor(FlashAttnProcessor())
|
||||
print("Enable memory efficient attention for U-Net")
|
||||
unet.set_use_memory_efficient_attention(False, True)
|
||||
elif xformers:
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
print("CrossAttention.forward has been replaced to enable xformers.")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
raise ImportError("No xformers / xformersがインストールされていないようです")
|
||||
|
||||
def forward_xformers(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
|
||||
context = default(context, x)
|
||||
context = context.to(x.dtype)
|
||||
|
||||
if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
|
||||
context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
context_k = context_k.to(x.dtype)
|
||||
context_v = context_v.to(x.dtype)
|
||||
else:
|
||||
context_k = context
|
||||
context_v = context
|
||||
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
|
||||
out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
||||
|
||||
# diffusers 0.7.0~
|
||||
out = self.to_out[0](out)
|
||||
out = self.to_out[1](out)
|
||||
return out
|
||||
|
||||
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
print("Enable xformers for U-Net")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
raise ImportError("No xformers / xformersがインストールされていないようです")
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
elif sdpa:
|
||||
print("Enable SDPA for U-Net")
|
||||
unet.set_use_sdpa(True)
|
||||
|
||||
"""
|
||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||
@@ -2242,6 +2327,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)")
|
||||
parser.add_argument(
|
||||
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
|
||||
)
|
||||
@@ -2428,6 +2514,11 @@ def verify_training_args(args: argparse.Namespace):
|
||||
if args.adaptive_noise_scale is not None and args.noise_offset is None:
|
||||
raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
|
||||
|
||||
if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
|
||||
raise ValueError(
|
||||
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
|
||||
)
|
||||
|
||||
|
||||
def add_dataset_arguments(
|
||||
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
|
||||
@@ -2506,7 +2597,6 @@ def add_dataset_arguments(
|
||||
default=1,
|
||||
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token_warmup_step",
|
||||
type=float,
|
||||
@@ -2514,6 +2604,13 @@ def add_dataset_arguments(
|
||||
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset_class",
|
||||
type=str,
|
||||
default=None,
|
||||
help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
|
||||
)
|
||||
|
||||
if support_caption_dropout:
|
||||
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||
@@ -2788,15 +2885,7 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = torch.optim.SGD
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.startswith("DAdapt".lower()):
|
||||
# DAdaptation family
|
||||
# check dadaptation is installed
|
||||
try:
|
||||
import dadaptation
|
||||
import dadaptation.experimental as experimental
|
||||
except ImportError:
|
||||
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
||||
|
||||
elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
|
||||
# check lr and lr_count, and print warning
|
||||
actual_lr = lr
|
||||
lr_count = 1
|
||||
@@ -2809,40 +2898,60 @@ def get_optimizer(args, trainable_params):
|
||||
|
||||
if actual_lr <= 0.1:
|
||||
print(
|
||||
f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
|
||||
f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
|
||||
)
|
||||
print("recommend option: lr=1.0 / 推奨は1.0です")
|
||||
if lr_count > 1:
|
||||
print(
|
||||
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
|
||||
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
|
||||
)
|
||||
|
||||
# set optimizer
|
||||
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
|
||||
optimizer_class = experimental.DAdaptAdamPreprint
|
||||
print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdaGrad".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdaGrad
|
||||
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdam".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdam
|
||||
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdan".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdan
|
||||
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdanIP".lower():
|
||||
optimizer_class = experimental.DAdaptAdanIP
|
||||
print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptLion".lower():
|
||||
optimizer_class = dadaptation.DAdaptLion
|
||||
print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptSGD".lower():
|
||||
optimizer_class = dadaptation.DAdaptSGD
|
||||
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
if optimizer_type.startswith("DAdapt".lower()):
|
||||
# DAdaptation family
|
||||
# check dadaptation is installed
|
||||
try:
|
||||
import dadaptation
|
||||
import dadaptation.experimental as experimental
|
||||
except ImportError:
|
||||
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
# set optimizer
|
||||
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
|
||||
optimizer_class = experimental.DAdaptAdamPreprint
|
||||
print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdaGrad".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdaGrad
|
||||
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdam".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdam
|
||||
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdan".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdan
|
||||
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdanIP".lower():
|
||||
optimizer_class = experimental.DAdaptAdanIP
|
||||
print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptLion".lower():
|
||||
optimizer_class = dadaptation.DAdaptLion
|
||||
print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptSGD".lower():
|
||||
optimizer_class = dadaptation.DAdaptSGD
|
||||
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
else:
|
||||
# Prodigy
|
||||
# check Prodigy is installed
|
||||
try:
|
||||
import prodigyopt
|
||||
except ImportError:
|
||||
raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
|
||||
|
||||
print(f"use Prodigy optimizer | {optimizer_kwargs}")
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "Adafactor".lower():
|
||||
# 引数を確認して適宜補正する
|
||||
@@ -3093,23 +3202,9 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=log_with,
|
||||
logging_dir=logging_dir,
|
||||
project_dir=logging_dir,
|
||||
)
|
||||
|
||||
# accelerateの互換性問題を解決する
|
||||
accelerator_0_15 = True
|
||||
try:
|
||||
accelerator.unwrap_model("dummy", True)
|
||||
print("Using accelerator 0.15.0 or above.")
|
||||
except TypeError:
|
||||
accelerator_0_15 = False
|
||||
|
||||
def unwrap_model(model):
|
||||
if accelerator_0_15:
|
||||
return accelerator.unwrap_model(model, True)
|
||||
return accelerator.unwrap_model(model)
|
||||
|
||||
return accelerator, unwrap_model
|
||||
return accelerator
|
||||
|
||||
|
||||
def prepare_dtype(args: argparse.Namespace):
|
||||
@@ -3146,11 +3241,26 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
print(
|
||||
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
||||
)
|
||||
raise ex
|
||||
text_encoder = pipe.text_encoder
|
||||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
# Diffusers U-Net to original U-Net
|
||||
# TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう
|
||||
# print(f"unet config: {unet.config}")
|
||||
original_unet = UNet2DConditionModel(
|
||||
unet.config.sample_size,
|
||||
unet.config.attention_head_dim,
|
||||
unet.config.cross_attention_dim,
|
||||
unet.config.use_linear_projection,
|
||||
unet.config.upcast_attention,
|
||||
)
|
||||
original_unet.load_state_dict(unet.state_dict())
|
||||
unet = original_unet
|
||||
print("U-Net converted to original U-Net")
|
||||
|
||||
# VAEを読み込む
|
||||
if args.vae is not None:
|
||||
vae = model_util.load_vae(args.vae, weight_dtype)
|
||||
@@ -3580,6 +3690,7 @@ def sample_images(
|
||||
requires_safety_checker=False,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
|
||||
pipeline.to(device)
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
@@ -3769,4 +3880,4 @@ class collater_class:
|
||||
# set epoch and step
|
||||
dataset.set_current_epoch(self.current_epoch.value)
|
||||
dataset.set_current_step(self.current_step.value)
|
||||
return examples[0]
|
||||
return examples[0]
|
||||
|
||||
Reference in New Issue
Block a user