diff --git a/library/attention_processors.py b/library/attention_processors.py new file mode 100644 index 00000000..310c2cb1 --- /dev/null +++ b/library/attention_processors.py @@ -0,0 +1,227 @@ +import math +from typing import Any +from einops import rearrange +import torch +from diffusers.models.attention_processor import Attention + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + +EPSILON = 1e-6 + + +class FlashAttentionFunction(torch.autograd.function.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full( + (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device + ) + + scale = q.shape[-1] ** -0.5 + + if mask is None: + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if row_mask is not None: + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if row_mask is not None: + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( + min=EPSILON + ) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum( + "... i j, ... j d -> ... i d", exp_weights, vc + ) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = ( + exp_row_max_diff * row_sums + + exp_block_row_max_diff * block_row_sums + ) + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( + (exp_block_row_max_diff / new_row_sums) * exp_values + ) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if row_mask is not None: + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +class FlashAttnProcessor: + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ) -> Any: + q_bucket_size = 512 + k_bucket_size = 1024 + + h = attn.heads + q = attn.to_q(hidden_states) + + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) + + if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None: + context_k, context_v = attn.hypernetwork.forward( + hidden_states, encoder_hidden_states + ) + context_k = context_k.to(hidden_states.dtype) + context_v = context_v.to(hidden_states.dtype) + else: + context_k = encoder_hidden_states + context_v = encoder_hidden_states + + k = attn.to_k(context_k) + v = attn.to_v(context_v) + del encoder_hidden_states, hidden_states + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = FlashAttentionFunction.apply( + q, k, v, attention_mask, False, q_bucket_size, k_bucket_size + ) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = attn.to_out[0](out) + out = attn.to_out[1](out) + return out diff --git a/library/hypernetwork.py b/library/hypernetwork.py new file mode 100644 index 00000000..fbd3fb24 --- /dev/null +++ b/library/hypernetwork.py @@ -0,0 +1,223 @@ +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor2_0, + SlicedAttnProcessor, + XFormersAttnProcessor +) + +try: + import xformers.ops +except: + xformers = None + + +loaded_networks = [] + + +def apply_single_hypernetwork( + hypernetwork, hidden_states, encoder_hidden_states +): + context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states) + return context_k, context_v + + +def apply_hypernetworks(context_k, context_v, layer=None): + if len(loaded_networks) == 0: + return context_v, context_v + for hypernetwork in loaded_networks: + context_k, context_v = hypernetwork.forward(context_k, context_v) + + context_k = context_k.to(dtype=context_k.dtype) + context_v = context_v.to(dtype=context_k.dtype) + + return context_k, context_v + + + +def xformers_forward( + self: XFormersAttnProcessor, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: torch.Tensor = None, +): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) + + key = attn.to_k(context_k) + value = attn.to_v(context_v) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +def sliced_attn_forward( + self: SlicedAttnProcessor, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: torch.Tensor = None, +): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) + + key = attn.to_k(context_k) + value = attn.to_v(context_v) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, _ = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), + device=query.device, + dtype=query.dtype, + ) + + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = ( + attention_mask[start_idx:end_idx] if attention_mask is not None else None + ) + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +def v2_0_forward( + self: AttnProcessor2_0, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, +): + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states) + + key = attn.to_k(context_k) + value = attn.to_v(context_v) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +def replace_attentions_for_hypernetwork(): + import diffusers.models.attention_processor + + diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = ( + xformers_forward + ) + diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = ( + sliced_attn_forward + ) + diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 3e04b887..84e1ab15 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -464,10 +464,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: SchedulerMixin, - clip_skip: int, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, + clip_skip: int = 1, ): super().__init__( vae=vae, diff --git a/library/train_util.py b/library/train_util.py index 46c5c3b2..008ccd64 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -63,6 +63,8 @@ import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util +from library.attention_processors import FlashAttnProcessor +from library.hypernetwork import replace_attentions_for_hypernetwork # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" @@ -1630,209 +1632,14 @@ def get_git_revision_hash() -> str: return "(unknown)" -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.function.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = q.shape[-1] ** -0.5 - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, "b n -> b 1 1 n") - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( - q_start_index - k_start_index + 1 - ) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) - - p = exp_attn_weights / lc - - dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) - dp = einsum("... i d, ... j d -> ... i j", doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) - dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + replace_attentions_for_hypernetwork() # unet is not used currently, but it is here for future use if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() + unet.set_attn_processor(FlashAttnProcessor()) elif xformers: - replace_unet_cross_attn_to_xformers() - - -def replace_unet_cross_attn_to_memory_efficient(): - print("CrossAttention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - - if hasattr(self, "hypernetwork") and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + unet.enable_xformers_memory_efficient_attention() def replace_unet_cross_attn_to_xformers(): @@ -3458,10 +3265,10 @@ def sample_images( unet=unet, tokenizer=tokenizer, scheduler=scheduler, - clip_skip=args.clip_skip, safety_checker=None, feature_extractor=None, requires_safety_checker=False, + clip_skip=args.clip_skip, ) pipeline.to(device) diff --git a/networks/lora.py b/networks/lora.py index 19fbbbdb..3a475e25 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -665,7 +665,7 @@ class LoRANetwork(torch.nn.Module): NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;) - UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] LORA_PREFIX_UNET = "lora_unet" diff --git a/requirements.txt b/requirements.txt index 801cf321..96da36a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ -accelerate==0.15.0 -transformers==4.26.0 +accelerate==0.19.0 +transformers==4.29.2 +diffusers[torch]==0.16.1 ftfy==6.1.1 albumentations==1.3.0 opencv-python==4.7.0.68 einops==0.6.0 -diffusers[torch]==0.10.2 pytorch-lightning==1.9.0 bitsandbytes==0.35.0 tensorboard==2.10.1 diff --git a/train_network.py b/train_network.py index cd90b0a2..109d1ff2 100644 --- a/train_network.py +++ b/train_network.py @@ -6,7 +6,6 @@ import os import random import time import json -import toml from multiprocessing import Value from tqdm import tqdm @@ -165,7 +164,7 @@ def train(args): import sys sys.path.append(os.path.dirname(__file__)) - print("import network module:", args.network_module) + accelerator.print("import network module:", args.network_module) network_module = importlib.import_module(args.network_module) if args.base_weights is not None: @@ -176,14 +175,15 @@ def train(args): else: multiplier = args.base_weights_multiplier[i] - print(f"merging module: {weight_path} with multiplier {multiplier}") + accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") module, weights_sd = network_module.create_network_from_weights( multiplier, weight_path, vae, text_encoder, unet, for_inference=True ) module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") - print(f"all weights merged: {', '.join(args.base_weights)}") + accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") + # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -225,7 +225,7 @@ def train(args): if args.network_weights is not None: info = network.load_weights(args.network_weights) - print(f"loaded network weights from {args.network_weights}: {info}") + accelerator.print(f"load network weights from {args.network_weights}: {info}") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -233,13 +233,13 @@ def train(args): network.enable_gradient_checkpointing() # may have no effect # 学習に必要なクラスを準備する - print("preparing optimizer, data loader etc.") + accelerator.print("prepare optimizer, data loader etc.") # 後方互換性を確保するよ try: trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) except TypeError: - print( + accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" ) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) @@ -264,8 +264,7 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil( len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) - if is_main_process: - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) @@ -278,7 +277,7 @@ def train(args): assert ( args.mixed_precision == "fp16" ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enabling full fp16 training.") + accelerator.print("enable full fp16 training.") network.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい @@ -338,16 +337,15 @@ def train(args): # TODO: find a way to handle total batch size when there are multiple datasets total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - if is_main_process: - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") # TODO refactor metadata creation and move to util metadata = { @@ -572,7 +570,7 @@ def train(args): os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"\nsaving checkpoint: {ckpt_file}") + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") metadata["ss_training_finished_at"] = str(time.time()) metadata["ss_steps"] = str(steps) metadata["ss_epoch"] = str(epoch_no) @@ -584,13 +582,12 @@ def train(args): def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) # training loop for epoch in range(num_train_epochs): - if is_main_process: - print(f"\nepoch {epoch+1}/{num_train_epochs}") + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1)