mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
update diffusers to 1.16 | train_network
This commit is contained in:
227
library/attention_processors.py
Normal file
227
library/attention_processors.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
import math
|
||||||
|
from typing import Any
|
||||||
|
from einops import rearrange
|
||||||
|
import torch
|
||||||
|
from diffusers.models.attention_processor import Attention
|
||||||
|
|
||||||
|
|
||||||
|
# flash attention forwards and backwards
|
||||||
|
|
||||||
|
# https://arxiv.org/abs/2205.14135
|
||||||
|
|
||||||
|
EPSILON = 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
class FlashAttentionFunction(torch.autograd.function.Function):
    """Memory-efficient attention (FlashAttention, https://arxiv.org/abs/2205.14135).

    Computes softmax(Q·K^T * scale)·V without materializing the full attention
    matrix: rows (queries) and columns (keys/values) are processed in buckets
    while running per-row max/sum statistics keep the online softmax
    numerically stable. Both passes run under @torch.no_grad() because the
    gradient is produced analytically in backward().
    """

    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""
        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        # Offset for the causal mask when k is longer than q along the
        # sequence dim (e.g. extra prepended context tokens).
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        # Running softmax denominator and running row max, one scalar per
        # query position; updated block-by-block below.
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full(
            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
        )

        scale = q.shape[-1] ** -0.5

        if mask is None:
            # One placeholder per query bucket so the zip() below stays aligned.
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            # NOTE(review): assumes mask is (batch, key_len) of bools where
            # True = attend; it is split with q_bucket_size — confirm callers
            # only pass masks when q and k lengths match.
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # Scaled attention logits for this (row-bucket, col-bucket) tile.
                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if row_mask is not None:
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                # Only build a causal mask for tiles that straddle the diagonal.
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # Online softmax: subtract the tile's row max before exp.
                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if row_mask is not None:
                    exp_weights.masked_fill_(~row_mask, 0.0)

                # Clamp so fully-masked rows never divide by zero.
                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum(
                    "... i j, ... j d -> ... i d", exp_weights, vc
                )

                # Rescale factors that renormalize previous partial results to
                # the new running maximum.
                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                # Fold this tile's contribution into the running output in place
                # (oc is a view into o, so o accumulates across tiles).
                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""
        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        # l = saved row sums (softmax denominators), m = saved row maxes.
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                # Recompute this tile's logits instead of storing them (the
                # memory saving that defines flash attention).
                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # Reconstruct softmax probabilities from the saved statistics.
                exp_attn_weights = torch.exp(attn_weights - mc)

                if row_mask is not None:
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                # D_i = sum_d dO_i * O_i, the softmax-jacobian correction term.
                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        # One gradient per forward() argument; the non-tensor args get None.
        return dq, dk, dv, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class FlashAttnProcessor:
    """diffusers attention processor that runs attention through
    FlashAttentionFunction, optionally letting a hypernetwork attached to the
    Attention module rewrite the key/value contexts first.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ) -> Any:
        # Tiling granularity for the flash kernel.
        q_bucket_size = 512
        k_bucket_size = 1024

        num_heads = attn.heads
        q = attn.to_q(hidden_states)

        # Self-attention when no encoder context is supplied.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)

        hypernetwork = getattr(attn, "hypernetwork", None)
        if hypernetwork is not None:
            context_k, context_v = hypernetwork.forward(
                hidden_states, encoder_hidden_states
            )
            context_k = context_k.to(hidden_states.dtype)
            context_v = context_v.to(hidden_states.dtype)
        else:
            context_k = encoder_hidden_states
            context_v = encoder_hidden_states

        k = attn.to_k(context_k)
        v = attn.to_v(context_v)
        del encoder_hidden_states, hidden_states

        # (b, n, h*d) -> (b, h, n, d) as expected by the flash kernel.
        q, k, v = (
            rearrange(t, "b n (h d) -> b h n d", h=num_heads) for t in (q, k, v)
        )

        out = FlashAttentionFunction.apply(
            q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
        )

        out = rearrange(out, "b h n d -> b n (h d)")

        # Output projection, then dropout.
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out
|
||||||
223
library/hypernetwork.py
Normal file
223
library/hypernetwork.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from diffusers.models.attention_processor import (
|
||||||
|
Attention,
|
||||||
|
AttnProcessor2_0,
|
||||||
|
SlicedAttnProcessor,
|
||||||
|
XFormersAttnProcessor
|
||||||
|
)
|
||||||
|
|
||||||
|
# xformers is an optional dependency; when it is missing we record that by
# setting the module reference to None. Narrowed from a bare `except:` so
# unrelated failures (KeyboardInterrupt, SystemExit) are not swallowed.
try:
    import xformers.ops
except ImportError:
    xformers = None


# Hypernetworks appended here are applied in order by apply_hypernetworks().
loaded_networks = []
|
||||||
|
|
||||||
|
|
||||||
|
def apply_single_hypernetwork(
    hypernetwork, hidden_states, encoder_hidden_states
):
    """Run one hypernetwork and return its (context_k, context_v) pair."""
    return hypernetwork.forward(hidden_states, encoder_hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_hypernetworks(context_k, context_v, layer=None):
    """Chain every loaded hypernetwork over the (key, value) context tensors.

    With no hypernetworks loaded this degrades to plain cross attention:
    both key and value contexts come from the same tensor (context_v).
    `layer` is accepted for interface compatibility but unused here.
    """
    if not loaded_networks:
        return context_v, context_v

    for network in loaded_networks:
        context_k, context_v = network.forward(context_k, context_v)

    # NOTE(review): casting context_k to its own dtype is a no-op; presumably
    # this was meant to target the incoming hidden-state dtype — confirm.
    context_k = context_k.to(dtype=context_k.dtype)
    context_v = context_v.to(dtype=context_k.dtype)

    return context_k, context_v
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_forward(
    self: XFormersAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    """Hypernetwork-aware replacement for XFormersAttnProcessor.__call__."""
    if encoder_hidden_states is None:
        batch_size, sequence_length, _ = hidden_states.shape
    else:
        batch_size, sequence_length, _ = encoder_hidden_states.shape

    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # Let any loaded hypernetworks rewrite the key/value contexts.
    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    # xformers wants contiguous (batch*heads, seq, head_dim) tensors.
    query = attn.head_to_batch_dim(query).contiguous()
    key = attn.head_to_batch_dim(key).contiguous()
    value = attn.head_to_batch_dim(value).contiguous()

    hidden_states = xformers.ops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=attention_mask,
        op=self.attention_op,
        scale=attn.scale,
    )
    hidden_states = hidden_states.to(query.dtype)
    hidden_states = attn.batch_to_head_dim(hidden_states)

    # Output projection, then dropout.
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def sliced_attn_forward(
    self: SlicedAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    """Hypernetwork-aware replacement for SlicedAttnProcessor.__call__."""
    if encoder_hidden_states is None:
        batch_size, sequence_length, _ = hidden_states.shape
    else:
        batch_size, sequence_length, _ = encoder_hidden_states.shape
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)
    dim = query.shape[-1]
    query = attn.head_to_batch_dim(query)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # Let any loaded hypernetworks rewrite the key/value contexts.
    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.head_to_batch_dim(attn.to_k(context_k))
    value = attn.head_to_batch_dim(attn.to_v(context_v))

    batch_size_attention, query_tokens, _ = query.shape
    hidden_states = torch.zeros(
        (batch_size_attention, query_tokens, dim // attn.heads),
        device=query.device,
        dtype=query.dtype,
    )

    # Walk the (batch * heads) dimension in slices of self.slice_size so the
    # attention-score matrix for each slice stays small.
    for slice_index in range(batch_size_attention // self.slice_size):
        lo = slice_index * self.slice_size
        hi = (slice_index + 1) * self.slice_size

        mask_slice = attention_mask[lo:hi] if attention_mask is not None else None
        scores = attn.get_attention_scores(query[lo:hi], key[lo:hi], mask_slice)
        hidden_states[lo:hi] = torch.bmm(scores, value[lo:hi])

    hidden_states = attn.batch_to_head_dim(hidden_states)

    # Output projection, then dropout.
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)

    return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def v2_0_forward(
    self: AttnProcessor2_0,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
):
    """Hypernetwork-aware replacement for AttnProcessor2_0.__call__ (SDPA path)."""
    if encoder_hidden_states is None:
        batch_size, sequence_length, _ = hidden_states.shape
    else:
        batch_size, sequence_length, _ = encoder_hidden_states.shape
    inner_dim = hidden_states.shape[-1]

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        # scaled_dot_product_attention expects the mask shaped as
        # (batch, heads, source_length, target_length).
        attention_mask = attention_mask.view(
            batch_size, attn.heads, -1, attention_mask.shape[-1]
        )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    # Let any loaded hypernetworks rewrite the key/value contexts.
    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    head_dim = inner_dim // attn.heads

    def _split_heads(t):
        # (b, n, h*d) -> (b, h, n, d)
        return t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    query = _split_heads(query)
    key = _split_heads(key)
    value = _split_heads(value)

    # sdp output shape: (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    # Output projection, then dropout.
    hidden_states = attn.to_out[0](hidden_states)
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def replace_attentions_for_hypernetwork():
    """Monkey-patch diffusers' attention processors so that every processor
    routes key/value contexts through the hypernetwork-aware forwards above.
    """
    import diffusers.models.attention_processor as attention_processor

    attention_processor.XFormersAttnProcessor.__call__ = xformers_forward
    attention_processor.SlicedAttnProcessor.__call__ = sliced_attn_forward
    attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
|
||||||
@@ -464,10 +464,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
scheduler: SchedulerMixin,
|
scheduler: SchedulerMixin,
|
||||||
clip_skip: int,
|
|
||||||
safety_checker: StableDiffusionSafetyChecker,
|
safety_checker: StableDiffusionSafetyChecker,
|
||||||
feature_extractor: CLIPFeatureExtractor,
|
feature_extractor: CLIPFeatureExtractor,
|
||||||
requires_safety_checker: bool = True,
|
requires_safety_checker: bool = True,
|
||||||
|
clip_skip: int = 1,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ import safetensors.torch
|
|||||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import library.huggingface_util as huggingface_util
|
import library.huggingface_util as huggingface_util
|
||||||
|
from library.attention_processors import FlashAttnProcessor
|
||||||
|
from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||||
@@ -1630,209 +1632,14 @@ def get_git_revision_hash() -> str:
|
|||||||
return "(unknown)"
|
return "(unknown)"
|
||||||
|
|
||||||
|
|
||||||
# flash attention forwards and backwards
|
|
||||||
|
|
||||||
# https://arxiv.org/abs/2205.14135
|
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionFunction(torch.autograd.function.Function):
|
|
||||||
@staticmethod
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
|
||||||
"""Algorithm 2 in the paper"""
|
|
||||||
|
|
||||||
device = q.device
|
|
||||||
dtype = q.dtype
|
|
||||||
max_neg_value = -torch.finfo(q.dtype).max
|
|
||||||
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
|
||||||
|
|
||||||
o = torch.zeros_like(q)
|
|
||||||
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
|
||||||
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
scale = q.shape[-1] ** -0.5
|
|
||||||
|
|
||||||
if not exists(mask):
|
|
||||||
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
|
||||||
else:
|
|
||||||
mask = rearrange(mask, "b n -> b 1 1 n")
|
|
||||||
mask = mask.split(q_bucket_size, dim=-1)
|
|
||||||
|
|
||||||
row_splits = zip(
|
|
||||||
q.split(q_bucket_size, dim=-2),
|
|
||||||
o.split(q_bucket_size, dim=-2),
|
|
||||||
mask,
|
|
||||||
all_row_sums.split(q_bucket_size, dim=-2),
|
|
||||||
all_row_maxes.split(q_bucket_size, dim=-2),
|
|
||||||
)
|
|
||||||
|
|
||||||
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
|
||||||
q_start_index = ind * q_bucket_size - qk_len_diff
|
|
||||||
|
|
||||||
col_splits = zip(
|
|
||||||
k.split(k_bucket_size, dim=-2),
|
|
||||||
v.split(k_bucket_size, dim=-2),
|
|
||||||
)
|
|
||||||
|
|
||||||
for k_ind, (kc, vc) in enumerate(col_splits):
|
|
||||||
k_start_index = k_ind * k_bucket_size
|
|
||||||
|
|
||||||
attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
|
||||||
|
|
||||||
if exists(row_mask):
|
|
||||||
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
|
||||||
|
|
||||||
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
|
||||||
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
|
||||||
q_start_index - k_start_index + 1
|
|
||||||
)
|
|
||||||
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
|
||||||
|
|
||||||
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
|
||||||
attn_weights -= block_row_maxes
|
|
||||||
exp_weights = torch.exp(attn_weights)
|
|
||||||
|
|
||||||
if exists(row_mask):
|
|
||||||
exp_weights.masked_fill_(~row_mask, 0.0)
|
|
||||||
|
|
||||||
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
|
||||||
|
|
||||||
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
|
||||||
|
|
||||||
exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc)
|
|
||||||
|
|
||||||
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
|
||||||
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
|
||||||
|
|
||||||
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
|
||||||
|
|
||||||
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
|
||||||
|
|
||||||
row_maxes.copy_(new_row_maxes)
|
|
||||||
row_sums.copy_(new_row_sums)
|
|
||||||
|
|
||||||
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
|
||||||
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
|
||||||
|
|
||||||
return o
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@torch.no_grad()
|
|
||||||
def backward(ctx, do):
|
|
||||||
"""Algorithm 4 in the paper"""
|
|
||||||
|
|
||||||
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
|
||||||
q, k, v, o, l, m = ctx.saved_tensors
|
|
||||||
|
|
||||||
device = q.device
|
|
||||||
|
|
||||||
max_neg_value = -torch.finfo(q.dtype).max
|
|
||||||
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
|
||||||
|
|
||||||
dq = torch.zeros_like(q)
|
|
||||||
dk = torch.zeros_like(k)
|
|
||||||
dv = torch.zeros_like(v)
|
|
||||||
|
|
||||||
row_splits = zip(
|
|
||||||
q.split(q_bucket_size, dim=-2),
|
|
||||||
o.split(q_bucket_size, dim=-2),
|
|
||||||
do.split(q_bucket_size, dim=-2),
|
|
||||||
mask,
|
|
||||||
l.split(q_bucket_size, dim=-2),
|
|
||||||
m.split(q_bucket_size, dim=-2),
|
|
||||||
dq.split(q_bucket_size, dim=-2),
|
|
||||||
)
|
|
||||||
|
|
||||||
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
|
||||||
q_start_index = ind * q_bucket_size - qk_len_diff
|
|
||||||
|
|
||||||
col_splits = zip(
|
|
||||||
k.split(k_bucket_size, dim=-2),
|
|
||||||
v.split(k_bucket_size, dim=-2),
|
|
||||||
dk.split(k_bucket_size, dim=-2),
|
|
||||||
dv.split(k_bucket_size, dim=-2),
|
|
||||||
)
|
|
||||||
|
|
||||||
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
|
||||||
k_start_index = k_ind * k_bucket_size
|
|
||||||
|
|
||||||
attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
|
||||||
|
|
||||||
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
|
||||||
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
|
||||||
q_start_index - k_start_index + 1
|
|
||||||
)
|
|
||||||
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
|
||||||
|
|
||||||
exp_attn_weights = torch.exp(attn_weights - mc)
|
|
||||||
|
|
||||||
if exists(row_mask):
|
|
||||||
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
|
||||||
|
|
||||||
p = exp_attn_weights / lc
|
|
||||||
|
|
||||||
dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc)
|
|
||||||
dp = einsum("... i d, ... j d -> ... i j", doc, vc)
|
|
||||||
|
|
||||||
D = (doc * oc).sum(dim=-1, keepdims=True)
|
|
||||||
ds = p * scale * (dp - D)
|
|
||||||
|
|
||||||
dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc)
|
|
||||||
dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc)
|
|
||||||
|
|
||||||
dqc.add_(dq_chunk)
|
|
||||||
dkc.add_(dk_chunk)
|
|
||||||
dvc.add_(dv_chunk)
|
|
||||||
|
|
||||||
return dq, dk, dv, None, None, None, None
|
|
||||||
|
|
||||||
|
|
||||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||||
|
replace_attentions_for_hypernetwork()
|
||||||
# unet is not used currently, but it is here for future use
|
# unet is not used currently, but it is here for future use
|
||||||
if mem_eff_attn:
|
if mem_eff_attn:
|
||||||
replace_unet_cross_attn_to_memory_efficient()
|
unet.set_attn_processor(FlashAttnProcessor())
|
||||||
elif xformers:
|
elif xformers:
|
||||||
replace_unet_cross_attn_to_xformers()
|
unet.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
|
||||||
def replace_unet_cross_attn_to_memory_efficient():
|
|
||||||
print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
|
|
||||||
flash_func = FlashAttentionFunction
|
|
||||||
|
|
||||||
def forward_flash_attn(self, x, context=None, mask=None):
|
|
||||||
q_bucket_size = 512
|
|
||||||
k_bucket_size = 1024
|
|
||||||
|
|
||||||
h = self.heads
|
|
||||||
q = self.to_q(x)
|
|
||||||
|
|
||||||
context = context if context is not None else x
|
|
||||||
context = context.to(x.dtype)
|
|
||||||
|
|
||||||
if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
|
|
||||||
context_k, context_v = self.hypernetwork.forward(x, context)
|
|
||||||
context_k = context_k.to(x.dtype)
|
|
||||||
context_v = context_v.to(x.dtype)
|
|
||||||
else:
|
|
||||||
context_k = context
|
|
||||||
context_v = context
|
|
||||||
|
|
||||||
k = self.to_k(context_k)
|
|
||||||
v = self.to_v(context_v)
|
|
||||||
del context, x
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
|
||||||
|
|
||||||
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
|
||||||
|
|
||||||
out = rearrange(out, "b h n d -> b n (h d)")
|
|
||||||
|
|
||||||
# diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
|
|
||||||
out = self.to_out[0](out)
|
|
||||||
out = self.to_out[1](out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
|
||||||
|
|
||||||
|
|
||||||
def replace_unet_cross_attn_to_xformers():
|
def replace_unet_cross_attn_to_xformers():
|
||||||
@@ -3458,10 +3265,10 @@ def sample_images(
|
|||||||
unet=unet,
|
unet=unet,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
clip_skip=args.clip_skip,
|
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
feature_extractor=None,
|
feature_extractor=None,
|
||||||
requires_safety_checker=False,
|
requires_safety_checker=False,
|
||||||
|
clip_skip=args.clip_skip,
|
||||||
)
|
)
|
||||||
pipeline.to(device)
|
pipeline.to(device)
|
||||||
|
|
||||||
|
|||||||
@@ -665,7 +665,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
||||||
|
|
||||||
# is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
|
# is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
|
||||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||||
LORA_PREFIX_UNET = "lora_unet"
|
LORA_PREFIX_UNET = "lora_unet"
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
accelerate==0.15.0
|
accelerate==0.19.0
|
||||||
transformers==4.26.0
|
transformers==4.29.2
|
||||||
|
diffusers[torch]==0.16.1
|
||||||
ftfy==6.1.1
|
ftfy==6.1.1
|
||||||
albumentations==1.3.0
|
albumentations==1.3.0
|
||||||
opencv-python==4.7.0.68
|
opencv-python==4.7.0.68
|
||||||
einops==0.6.0
|
einops==0.6.0
|
||||||
diffusers[torch]==0.10.2
|
|
||||||
pytorch-lightning==1.9.0
|
pytorch-lightning==1.9.0
|
||||||
bitsandbytes==0.35.0
|
bitsandbytes==0.35.0
|
||||||
tensorboard==2.10.1
|
tensorboard==2.10.1
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import os
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import toml
|
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -165,7 +164,7 @@ def train(args):
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(__file__))
|
sys.path.append(os.path.dirname(__file__))
|
||||||
print("import network module:", args.network_module)
|
accelerator.print("import network module:", args.network_module)
|
||||||
network_module = importlib.import_module(args.network_module)
|
network_module = importlib.import_module(args.network_module)
|
||||||
|
|
||||||
if args.base_weights is not None:
|
if args.base_weights is not None:
|
||||||
@@ -176,14 +175,15 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
multiplier = args.base_weights_multiplier[i]
|
multiplier = args.base_weights_multiplier[i]
|
||||||
|
|
||||||
print(f"merging module: {weight_path} with multiplier {multiplier}")
|
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
|
||||||
|
|
||||||
module, weights_sd = network_module.create_network_from_weights(
|
module, weights_sd = network_module.create_network_from_weights(
|
||||||
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
||||||
)
|
)
|
||||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||||
|
|
||||||
print(f"all weights merged: {', '.join(args.base_weights)}")
|
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
@@ -225,7 +225,7 @@ def train(args):
|
|||||||
|
|
||||||
if args.network_weights is not None:
|
if args.network_weights is not None:
|
||||||
info = network.load_weights(args.network_weights)
|
info = network.load_weights(args.network_weights)
|
||||||
print(f"loaded network weights from {args.network_weights}: {info}")
|
accelerator.print(f"load network weights from {args.network_weights}: {info}")
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
@@ -233,13 +233,13 @@ def train(args):
|
|||||||
network.enable_gradient_checkpointing() # may have no effect
|
network.enable_gradient_checkpointing() # may have no effect
|
||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
print("preparing optimizer, data loader etc.")
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
# 後方互換性を確保するよ
|
# 後方互換性を確保するよ
|
||||||
try:
|
try:
|
||||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
print(
|
accelerator.print(
|
||||||
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
||||||
)
|
)
|
||||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||||
@@ -264,8 +264,7 @@ def train(args):
|
|||||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
)
|
)
|
||||||
if is_main_process:
|
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
|
||||||
|
|
||||||
# データセット側にも学習ステップを送信
|
# データセット側にも学習ステップを送信
|
||||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||||
@@ -278,7 +277,7 @@ def train(args):
|
|||||||
assert (
|
assert (
|
||||||
args.mixed_precision == "fp16"
|
args.mixed_precision == "fp16"
|
||||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||||
print("enabling full fp16 training.")
|
accelerator.print("enable full fp16 training.")
|
||||||
network.to(weight_dtype)
|
network.to(weight_dtype)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
@@ -338,16 +337,15 @@ def train(args):
|
|||||||
# TODO: find a way to handle total batch size when there are multiple datasets
|
# TODO: find a way to handle total batch size when there are multiple datasets
|
||||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
|
||||||
if is_main_process:
|
accelerator.print("running training / 学習開始")
|
||||||
print("running training / 学習開始")
|
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||||
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
# accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
|
||||||
|
|
||||||
# TODO refactor metadata creation and move to util
|
# TODO refactor metadata creation and move to util
|
||||||
metadata = {
|
metadata = {
|
||||||
@@ -572,7 +570,7 @@ def train(args):
|
|||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
||||||
metadata["ss_training_finished_at"] = str(time.time())
|
metadata["ss_training_finished_at"] = str(time.time())
|
||||||
metadata["ss_steps"] = str(steps)
|
metadata["ss_steps"] = str(steps)
|
||||||
metadata["ss_epoch"] = str(epoch_no)
|
metadata["ss_epoch"] = str(epoch_no)
|
||||||
@@ -584,13 +582,12 @@ def train(args):
|
|||||||
def remove_model(old_ckpt_name):
|
def remove_model(old_ckpt_name):
|
||||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||||
if os.path.exists(old_ckpt_file):
|
if os.path.exists(old_ckpt_file):
|
||||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||||
os.remove(old_ckpt_file)
|
os.remove(old_ckpt_file)
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
if is_main_process:
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
metadata["ss_epoch"] = str(epoch + 1)
|
metadata["ss_epoch"] = str(epoch + 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user