update diffusers to 1.16 | train_network

This commit is contained in:
ddPn08
2023-06-01 10:29:48 +09:00
parent f8e8df5a04
commit c8d209d36c
7 changed files with 482 additions and 228 deletions

View File

@@ -0,0 +1,227 @@
import math
from typing import Any
from einops import rearrange
import torch
from diffusers.models.attention_processor import Attention
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135
EPSILON = 1e-6
class FlashAttentionFunction(torch.autograd.function.Function):
@staticmethod
@torch.no_grad()
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
"""Algorithm 2 in the paper"""
device = q.device
dtype = q.dtype
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
o = torch.zeros_like(q)
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
all_row_maxes = torch.full(
(*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
)
scale = q.shape[-1] ** -0.5
if mask is None:
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
else:
mask = rearrange(mask, "b n -> b 1 1 n")
mask = mask.split(q_bucket_size, dim=-1)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
mask,
all_row_sums.split(q_bucket_size, dim=-2),
all_row_maxes.split(q_bucket_size, dim=-2),
)
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = (
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
)
if row_mask is not None:
attn_weights.masked_fill_(~row_mask, max_neg_value)
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones(
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, max_neg_value)
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
attn_weights -= block_row_maxes
exp_weights = torch.exp(attn_weights)
if row_mask is not None:
exp_weights.masked_fill_(~row_mask, 0.0)
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
min=EPSILON
)
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
exp_values = torch.einsum(
"... i j, ... j d -> ... i d", exp_weights, vc
)
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
new_row_sums = (
exp_row_max_diff * row_sums
+ exp_block_row_max_diff * block_row_sums
)
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
(exp_block_row_max_diff / new_row_sums) * exp_values
)
row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
return o
@staticmethod
@torch.no_grad()
def backward(ctx, do):
"""Algorithm 4 in the paper"""
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
q, k, v, o, l, m = ctx.saved_tensors
device = q.device
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
dq = torch.zeros_like(q)
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
do.split(q_bucket_size, dim=-2),
mask,
l.split(q_bucket_size, dim=-2),
m.split(q_bucket_size, dim=-2),
dq.split(q_bucket_size, dim=-2),
)
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
dk.split(k_bucket_size, dim=-2),
dv.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = (
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
)
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones(
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, max_neg_value)
exp_attn_weights = torch.exp(attn_weights - mc)
if row_mask is not None:
exp_attn_weights.masked_fill_(~row_mask, 0.0)
p = exp_attn_weights / lc
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
D = (doc * oc).sum(dim=-1, keepdims=True)
ds = p * scale * (dp - D)
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
dqc.add_(dq_chunk)
dkc.add_(dk_chunk)
dvc.add_(dv_chunk)
return dq, dk, dv, None, None, None, None
class FlashAttnProcessor:
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
) -> Any:
q_bucket_size = 512
k_bucket_size = 1024
h = attn.heads
q = attn.to_q(hidden_states)
encoder_hidden_states = (
encoder_hidden_states
if encoder_hidden_states is not None
else hidden_states
)
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
context_k, context_v = attn.hypernetwork.forward(
hidden_states, encoder_hidden_states
)
context_k = context_k.to(hidden_states.dtype)
context_v = context_v.to(hidden_states.dtype)
else:
context_k = encoder_hidden_states
context_v = encoder_hidden_states
k = attn.to_k(context_k)
v = attn.to_v(context_v)
del encoder_hidden_states, hidden_states
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
out = FlashAttentionFunction.apply(
q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
)
out = rearrange(out, "b h n d -> b n (h d)")
out = attn.to_out[0](out)
out = attn.to_out[1](out)
return out

223
library/hypernetwork.py Normal file
View File

@@ -0,0 +1,223 @@
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import (
Attention,
AttnProcessor2_0,
SlicedAttnProcessor,
XFormersAttnProcessor
)
try:
import xformers.ops
except:
xformers = None
loaded_networks = []
def apply_single_hypernetwork(
hypernetwork, hidden_states, encoder_hidden_states
):
context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
return context_k, context_v
def apply_hypernetworks(context_k, context_v, layer=None):
if len(loaded_networks) == 0:
return context_v, context_v
for hypernetwork in loaded_networks:
context_k, context_v = hypernetwork.forward(context_k, context_v)
context_k = context_k.to(dtype=context_k.dtype)
context_v = context_v.to(dtype=context_k.dtype)
return context_k, context_v
def xformers_forward(
self: XFormersAttnProcessor,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: torch.Tensor = None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
key = attn.to_k(context_k)
value = attn.to_v(context_v)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attention_mask,
op=self.attention_op,
scale=attn.scale,
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def sliced_attn_forward(
self: SlicedAttnProcessor,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: torch.Tensor = None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
key = attn.to_k(context_k)
value = attn.to_v(context_v)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads),
device=query.device,
dtype=query.dtype,
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = (
attention_mask[start_idx:end_idx] if attention_mask is not None else None
)
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def v2_0_forward(
self: AttnProcessor2_0,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
inner_dim = hidden_states.shape[-1]
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
key = attn.to_k(context_k)
value = attn.to_v(context_v)
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def replace_attentions_for_hypernetwork():
import diffusers.models.attention_processor
diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
xformers_forward
)
diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
sliced_attn_forward
)
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward

View File

@@ -464,10 +464,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: SchedulerMixin, scheduler: SchedulerMixin,
clip_skip: int,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
clip_skip: int = 1,
): ):
super().__init__( super().__init__(
vae=vae, vae=vae,

View File

@@ -63,6 +63,8 @@ import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util import library.model_util as model_util
import library.huggingface_util as huggingface_util import library.huggingface_util as huggingface_util
from library.attention_processors import FlashAttnProcessor
from library.hypernetwork import replace_attentions_for_hypernetwork
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14" TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -1630,209 +1632,14 @@ def get_git_revision_hash() -> str:
return "(unknown)" return "(unknown)"
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135
class FlashAttentionFunction(torch.autograd.function.Function):
@staticmethod
@torch.no_grad()
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
"""Algorithm 2 in the paper"""
device = q.device
dtype = q.dtype
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
o = torch.zeros_like(q)
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
scale = q.shape[-1] ** -0.5
if not exists(mask):
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
else:
mask = rearrange(mask, "b n -> b 1 1 n")
mask = mask.split(q_bucket_size, dim=-1)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
mask,
all_row_sums.split(q_bucket_size, dim=-2),
all_row_maxes.split(q_bucket_size, dim=-2),
)
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
if exists(row_mask):
attn_weights.masked_fill_(~row_mask, max_neg_value)
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
q_start_index - k_start_index + 1
)
attn_weights.masked_fill_(causal_mask, max_neg_value)
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
attn_weights -= block_row_maxes
exp_weights = torch.exp(attn_weights)
if exists(row_mask):
exp_weights.masked_fill_(~row_mask, 0.0)
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
return o
@staticmethod
@torch.no_grad()
def backward(ctx, do):
"""Algorithm 4 in the paper"""
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
q, k, v, o, l, m = ctx.saved_tensors
device = q.device
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
dq = torch.zeros_like(q)
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
do.split(q_bucket_size, dim=-2),
mask,
l.split(q_bucket_size, dim=-2),
m.split(q_bucket_size, dim=-2),
dq.split(q_bucket_size, dim=-2),
)
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
dk.split(k_bucket_size, dim=-2),
dv.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
q_start_index - k_start_index + 1
)
attn_weights.masked_fill_(causal_mask, max_neg_value)
exp_attn_weights = torch.exp(attn_weights - mc)
if exists(row_mask):
exp_attn_weights.masked_fill_(~row_mask, 0.0)
p = exp_attn_weights / lc
dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
dp = einsum("... i d, ... j d -> ... i j", doc, vc)
D = (doc * oc).sum(dim=-1, keepdims=True)
ds = p * scale * (dp - D)
dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
dqc.add_(dq_chunk)
dkc.add_(dk_chunk)
dvc.add_(dv_chunk)
return dq, dk, dv, None, None, None, None
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
replace_attentions_for_hypernetwork()
# unet is not used currently, but it is here for future use # unet is not used currently, but it is here for future use
if mem_eff_attn: if mem_eff_attn:
replace_unet_cross_attn_to_memory_efficient() unet.set_attn_processor(FlashAttnProcessor())
elif xformers: elif xformers:
replace_unet_cross_attn_to_xformers() unet.enable_xformers_memory_efficient_attention()
def replace_unet_cross_attn_to_memory_efficient():
print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, x, context=None, mask=None):
q_bucket_size = 512
k_bucket_size = 1024
h = self.heads
q = self.to_q(x)
context = context if context is not None else x
context = context.to(x.dtype)
if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
context_k, context_v = self.hypernetwork.forward(x, context)
context_k = context_k.to(x.dtype)
context_v = context_v.to(x.dtype)
else:
context_k = context
context_v = context
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, x
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
out = rearrange(out, "b h n d -> b n (h d)")
# diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
out = self.to_out[0](out)
out = self.to_out[1](out)
return out
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
def replace_unet_cross_attn_to_xformers(): def replace_unet_cross_attn_to_xformers():
@@ -3458,10 +3265,10 @@ def sample_images(
unet=unet, unet=unet,
tokenizer=tokenizer, tokenizer=tokenizer,
scheduler=scheduler, scheduler=scheduler,
clip_skip=args.clip_skip,
safety_checker=None, safety_checker=None,
feature_extractor=None, feature_extractor=None,
requires_safety_checker=False, requires_safety_checker=False,
clip_skip=args.clip_skip,
) )
pipeline.to(device) pipeline.to(device)

View File

@@ -665,7 +665,7 @@ class LoRANetwork(torch.nn.Module):
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
# is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;) # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet" LORA_PREFIX_UNET = "lora_unet"

View File

@@ -1,10 +1,10 @@
accelerate==0.15.0 accelerate==0.19.0
transformers==4.26.0 transformers==4.29.2
diffusers[torch]==0.16.1
ftfy==6.1.1 ftfy==6.1.1
albumentations==1.3.0 albumentations==1.3.0
opencv-python==4.7.0.68 opencv-python==4.7.0.68
einops==0.6.0 einops==0.6.0
diffusers[torch]==0.10.2
pytorch-lightning==1.9.0 pytorch-lightning==1.9.0
bitsandbytes==0.35.0 bitsandbytes==0.35.0
tensorboard==2.10.1 tensorboard==2.10.1

View File

@@ -6,7 +6,6 @@ import os
import random import random
import time import time
import json import json
import toml
from multiprocessing import Value from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
@@ -165,7 +164,7 @@ def train(args):
import sys import sys
sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(__file__))
print("import network module:", args.network_module) accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module) network_module = importlib.import_module(args.network_module)
if args.base_weights is not None: if args.base_weights is not None:
@@ -176,14 +175,15 @@ def train(args):
else: else:
multiplier = args.base_weights_multiplier[i] multiplier = args.base_weights_multiplier[i]
print(f"merging module: {weight_path} with multiplier {multiplier}") accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
module, weights_sd = network_module.create_network_from_weights( module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True multiplier, weight_path, vae, text_encoder, unet, for_inference=True
) )
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
print(f"all weights merged: {', '.join(args.base_weights)}") accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
@@ -225,7 +225,7 @@ def train(args):
if args.network_weights is not None: if args.network_weights is not None:
info = network.load_weights(args.network_weights) info = network.load_weights(args.network_weights)
print(f"loaded network weights from {args.network_weights}: {info}") accelerator.print(f"load network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
@@ -233,13 +233,13 @@ def train(args):
network.enable_gradient_checkpointing() # may have no effect network.enable_gradient_checkpointing() # may have no effect
# 学習に必要なクラスを準備する # 学習に必要なクラスを準備する
print("preparing optimizer, data loader etc.") accelerator.print("prepare optimizer, data loader etc.")
# 後方互換性を確保するよ # 後方互換性を確保するよ
try: try:
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
except TypeError: except TypeError:
print( accelerator.print(
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
) )
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
@@ -264,8 +264,7 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil( args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
) )
if is_main_process: accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信 # データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps) train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -278,7 +277,7 @@ def train(args):
assert ( assert (
args.mixed_precision == "fp16" args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enabling full fp16 training.") accelerator.print("enable full fp16 training.")
network.to(weight_dtype) network.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい # acceleratorがなんかよろしくやってくれるらしい
@@ -338,16 +337,15 @@ def train(args):
# TODO: find a way to handle total batch size when there are multiple datasets # TODO: find a way to handle total batch size when there are multiple datasets
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
if is_main_process: accelerator.print("running training / 学習開始")
print("running training / 学習開始") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") accelerator.print(f" gradient accumulation steps / 勾配合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
# TODO refactor metadata creation and move to util # TODO refactor metadata creation and move to util
metadata = { metadata = {
@@ -572,7 +570,7 @@ def train(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name) ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"\nsaving checkpoint: {ckpt_file}") accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
metadata["ss_training_finished_at"] = str(time.time()) metadata["ss_training_finished_at"] = str(time.time())
metadata["ss_steps"] = str(steps) metadata["ss_steps"] = str(steps)
metadata["ss_epoch"] = str(epoch_no) metadata["ss_epoch"] = str(epoch_no)
@@ -584,13 +582,12 @@ def train(args):
def remove_model(old_ckpt_name): def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file): if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}") accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file) os.remove(old_ckpt_file)
# training loop # training loop
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
if is_main_process: accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1 current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1) metadata["ss_epoch"] = str(epoch + 1)