mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Replace nn.MultiheadAttention (nn.MHA) with a custom attention module that has separate q/k/v projections
This commit is contained in:
@@ -20,7 +20,7 @@ import diffusers
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from library.ipex_interop import init_ipex
|
from library.device_utils import init_ipex, clean_memory, get_preferred_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
@@ -338,7 +338,7 @@ class PipelineLike:
|
|||||||
self.clip_vision_model: CLIPVisionModelWithProjection = None
|
self.clip_vision_model: CLIPVisionModelWithProjection = None
|
||||||
self.clip_vision_processor: CLIPImageProcessor = None
|
self.clip_vision_processor: CLIPImageProcessor = None
|
||||||
self.clip_vision_strength = 0.0
|
self.clip_vision_strength = 0.0
|
||||||
|
|
||||||
# Textual Inversion
|
# Textual Inversion
|
||||||
self.token_replacements_list = []
|
self.token_replacements_list = []
|
||||||
for _ in range(len(self.text_encoders)):
|
for _ in range(len(self.text_encoders)):
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
import math
|
import math
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from einops import rearrange
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -148,7 +149,7 @@ class EfficientNetEncoder(nn.Module):
|
|||||||
The method to make it usable like VAE. It should be separated properly, but it is a temporary response.
|
The method to make it usable like VAE. It should be separated properly, but it is a temporary response.
|
||||||
"""
|
"""
|
||||||
# latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
# latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||||
|
|
||||||
# x is -1 to 1, so we need to convert it to 0 to 1, and then preprocess it with EfficientNet's preprocessing.
|
# x is -1 to 1, so we need to convert it to 0 to 1, and then preprocess it with EfficientNet's preprocessing.
|
||||||
x = (x + 1) / 2
|
x = (x + 1) / 2
|
||||||
x = EFFNET_PREPROCESS(x)
|
x = EFFNET_PREPROCESS(x)
|
||||||
@@ -172,6 +173,7 @@ from torch.nn import Conv2d
|
|||||||
from torch.nn import Linear
|
from torch.nn import Linear
|
||||||
|
|
||||||
|
|
||||||
|
r"""
|
||||||
class Attention2D(nn.Module):
|
class Attention2D(nn.Module):
|
||||||
def __init__(self, c, nhead, dropout=0.0):
|
def __init__(self, c, nhead, dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -185,6 +187,119 @@ class Attention2D(nn.Module):
|
|||||||
x = self.attn(x, kv, kv, need_weights=False)[0]
|
x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||||
return x
|
return x
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
    """Multi-head attention with separate query/key/value/output projections.

    Three interchangeable backends are supported and selected with
    :meth:`set_use_xformers_or_sdpa`: xformers memory-efficient attention,
    PyTorch scaled-dot-product attention (the default), and a plain
    ``baddbmm``/``softmax``/``bmm`` implementation.

    Args:
        c: Channel width of the inputs and outputs of every projection.
        nhead: Number of attention heads; ``c`` is split as ``(h d)`` with
            ``h = nhead``.
        dropout: Dropout probability for the attention weights. Only the
            SDPA backend uses it, and only while the module is in training
            mode; the xformers and plain backends ignore it.
    """

    def __init__(self, c, nhead, dropout=0.0):
        super().__init__()

        self.to_q = Linear(c, c, bias=True)
        self.to_k = Linear(c, c, bias=True)
        self.to_v = Linear(c, c, bias=True)
        self.to_out = Linear(c, c, bias=True)
        self.nhead = nhead
        self.dropout = dropout
        # 1 / sqrt(head_dim); used by the plain backend.
        self.scale = (c // nhead) ** -0.5

        # default is to use sdpa
        self.use_memory_efficient_attention_xformers = False
        self.use_sdpa = True

    def set_use_xformers_or_sdpa(self, xformers, sdpa):
        """Select the attention backend.

        xformers takes precedence over SDPA; when both flags are false the
        plain implementation in :meth:`_attention` is used.
        """
        self.use_memory_efficient_attention_xformers = xformers
        self.use_sdpa = sdpa

    def forward(self, q_in, k_in, v_in):
        """Project the inputs, run attention, and project the result.

        Inputs are laid out as ``(batch, tokens, c)``; the output has the
        same layout as ``q_in``.
        """
        q_in = self.to_q(q_in)
        k_in = self.to_k(k_in)
        v_in = self.to_v(v_in)

        # The intermediate tensors are released eagerly to keep peak memory low.
        if self.use_memory_efficient_attention_xformers:
            q, k, v = (rearrange(t, "b n (h d) -> b n h d", h=self.nhead) for t in (q_in, k_in, v_in))
            del q_in, k_in, v_in
            out = self.forward_memory_efficient_xformers(q, k, v)
            del q, k, v
            out = rearrange(out, "b n h d -> b n (h d)", h=self.nhead)
        elif self.use_sdpa:
            q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=self.nhead) for t in (q_in, k_in, v_in))
            del q_in, k_in, v_in
            out = self.forward_sdpa(q, k, v)
            del q, k, v
            out = rearrange(out, "b h n d -> b n (h d)", h=self.nhead)
        else:
            q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=self.nhead) for t in (q_in, k_in, v_in))
            del q_in, k_in, v_in
            out = self._attention(q, k, v)
            del q, k, v
            out = rearrange(out, "(b h) n d -> b n (h d)", h=self.nhead)

        return self.to_out(out)

    def _attention(self, query, key, value):
        """Plain attention on ``(batch * heads, tokens, head_dim)`` tensors."""
        # beta=0 means the uninitialized buffer only supplies shape/dtype/device;
        # its contents never reach the result.
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)

        # cast back to the dtype of the values
        attention_probs = attention_probs.to(value.dtype)

        # compute attention output
        return torch.bmm(attention_probs, value)

    def forward_memory_efficient_xformers(self, q, k, v):
        """xformers backend on ``(batch, tokens, heads, head_dim)`` tensors."""
        # Imported lazily so xformers is only required when this backend is used.
        import xformers.ops

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        # xformers picks the most suitable kernel on its own
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
        del q, k, v

        return out

    def forward_sdpa(self, q, k, v):
        """SDPA backend on ``(batch, heads, tokens, head_dim)`` tensors."""
        # scaled_dot_product_attention applies dropout whenever dropout_p > 0,
        # independent of the module mode, so it must be disabled explicitly
        # outside of training to keep inference deterministic.
        dropout_p = self.dropout if self.training else 0.0
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention2D(nn.Module):
    r"""
    Attention over the flattened trailing dimensions of a feature map.

    Modified to use an attention module with separate weights for
    to_q/k/v instead of ``nn.MultiheadAttention``.
    """

    def __init__(self, c, nhead, dropout=0.0):
        super().__init__()
        self.attn = Attention(c, nhead, dropout=dropout)

    def forward(self, x, kv, self_attn=False):
        """Attend from the positions of ``x`` to ``kv``.

        Args:
            x: Feature map; everything after the first two dimensions is
                flattened into a token axis and restored on return.
            kv: Conditioning tokens passed as both keys and values.
            self_attn: When true, the tokens of ``x`` are prepended to
                ``kv`` so that ``x`` also attends to itself.

        Returns:
            A tensor with the same shape as the input ``x``.
        """
        orig_shape = x.shape
        # (B, C, ...) -> (B, N, C). reshape is used instead of view so that
        # inputs whose layout cannot be flattened without a copy still work.
        x = x.reshape(x.size(0), x.size(1), -1).permute(0, 2, 1)
        if self_attn:
            kv = torch.cat([x, kv], dim=1)
        x = self.attn(x, kv, kv)
        x = x.permute(0, 2, 1).reshape(*orig_shape)
        return x

    def set_use_xformers_or_sdpa(self, xformers, sdpa):
        """Forward the attention-backend selection to the wrapped module."""
        self.attn.set_use_xformers_or_sdpa(xformers, sdpa)
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm2d(nn.LayerNorm):
|
class LayerNorm2d(nn.LayerNorm):
|
||||||
@@ -262,6 +377,9 @@ class AttnBlock(nn.Module):
|
|||||||
def set_gradient_checkpointing(self, value):
|
def set_gradient_checkpointing(self, value):
|
||||||
self.gradient_checkpointing = value
|
self.gradient_checkpointing = value
|
||||||
|
|
||||||
|
def set_use_xformers_or_sdpa(self, xformers, sdpa):
    """Pass the attention-backend flags on to this block's attention module."""
    attention = self.attention
    attention.set_use_xformers_or_sdpa(xformers, sdpa)
|
||||||
|
|
||||||
def forward_body(self, x, kv):
|
def forward_body(self, x, kv):
|
||||||
kv = self.kv_mapper(kv)
|
kv = self.kv_mapper(kv)
|
||||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
||||||
@@ -657,6 +775,12 @@ class StageB(nn.Module):
|
|||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def set_use_xformers_or_sdpa(self, xformers, sdpa):
    """Propagate the attention-backend flags to every layer that accepts them.

    Walks all layers of the down and up blocks and calls
    ``set_use_xformers_or_sdpa`` on those that define it.
    """
    all_blocks = self.down_blocks + self.up_blocks
    for block in all_blocks:
        for layer in block:
            if not hasattr(layer, "set_use_xformers_or_sdpa"):
                continue
            layer.set_use_xformers_or_sdpa(xformers, sdpa)
|
||||||
|
|
||||||
def gen_r_embedding(self, r, max_positions=10000):
|
def gen_r_embedding(self, r, max_positions=10000):
|
||||||
r = r * max_positions
|
r = r * max_positions
|
||||||
half_dim = self.c_r // 2
|
half_dim = self.c_r // 2
|
||||||
@@ -920,6 +1044,12 @@ class StageC(nn.Module):
|
|||||||
if hasattr(layer, "set_gradient_checkpointing"):
|
if hasattr(layer, "set_gradient_checkpointing"):
|
||||||
layer.set_gradient_checkpointing(value)
|
layer.set_gradient_checkpointing(value)
|
||||||
|
|
||||||
|
def set_use_xformers_or_sdpa(self, xformers, sdpa):
    """Propagate the attention-backend flags to every layer that accepts them.

    Walks all layers of the down and up blocks and calls
    ``set_use_xformers_or_sdpa`` on those that define it.
    """
    all_blocks = self.down_blocks + self.up_blocks
    for block in all_blocks:
        for layer in block:
            if not hasattr(layer, "set_use_xformers_or_sdpa"):
                continue
            layer.set_use_xformers_or_sdpa(xformers, sdpa)
|
||||||
|
|
||||||
def gen_r_embedding(self, r, max_positions=10000):
|
def gen_r_embedding(self, r, max_positions=10000):
|
||||||
r = r * max_positions
|
r = r * max_positions
|
||||||
half_dim = self.c_r // 2
|
half_dim = self.c_r // 2
|
||||||
|
|||||||
@@ -103,6 +103,9 @@ def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.
|
|||||||
generator_c = sc.StageC()
|
generator_c = sc.StageC()
|
||||||
logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
|
logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
|
||||||
stage_c_checkpoint = load_file(stage_c_checkpoint_path)
|
stage_c_checkpoint = load_file(stage_c_checkpoint_path)
|
||||||
|
|
||||||
|
stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)
|
||||||
|
|
||||||
logger.info(f"Loading state dict")
|
logger.info(f"Loading state dict")
|
||||||
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
|
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
|
||||||
logger.info(info)
|
logger.info(info)
|
||||||
@@ -115,6 +118,9 @@ def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.
|
|||||||
generator_b = sc.StageB()
|
generator_b = sc.StageB()
|
||||||
logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
|
logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
|
||||||
stage_b_checkpoint = load_file(stage_b_checkpoint_path)
|
stage_b_checkpoint = load_file(stage_b_checkpoint_path)
|
||||||
|
|
||||||
|
stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)
|
||||||
|
|
||||||
logger.info(f"Loading state dict")
|
logger.info(f"Loading state dict")
|
||||||
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
|
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
|
||||||
logger.info(info)
|
logger.info(info)
|
||||||
@@ -189,6 +195,55 @@ def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") ->
|
|||||||
return previewer
|
return previewer
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict_mha_to_normal_attn(state_dict):
    """Rename nn.MultiheadAttention entries to the to_q/to_k/to_v/to_out layout.

    Only keys containing ``"attention.attn."`` are touched. Fused
    ``in_proj_*`` tensors are split into three chunks along dim 0; the
    ``out_proj.*`` tensors are renamed. The dict is modified in place and
    also returned.
    """
    print("convert_state_dict_mha_to_normal_attn")

    # (marker in the key, replacement names); checked in this order.
    rules = (
        ("in_proj_bias", ("to_q.bias", "to_k.bias", "to_v.bias")),
        ("in_proj_weight", ("to_q.weight", "to_k.weight", "to_v.weight")),
        ("out_proj.bias", ("to_out.bias",)),
        ("out_proj.weight", ("to_out.weight",)),
    )

    for name in tuple(state_dict.keys()):
        if "attention.attn." not in name:
            continue

        matched = next((rule for rule in rules if rule[0] in name), None)
        if matched is None:
            continue

        marker, targets = matched
        tensor = state_dict.pop(name)
        if len(targets) == 1:
            # plain rename of the output projection
            state_dict[name.replace(marker, targets[0])] = tensor
            continue

        # fused projection: split into q, k, v along the first dimension
        pieces = torch.chunk(tensor, 3, dim=0)
        for index, target in enumerate(targets):
            state_dict[name.replace(marker, target)] = pieces[index]

    return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict_normal_attn_to_mha(state_dict):
    """Rename to_q/to_k/to_v/to_out entries back to the nn.MultiheadAttention layout.

    Only keys containing ``"attention.attn."`` are touched. Each
    ``to_q``/``to_k``/``to_v`` triple is concatenated (in that order) into
    a fused ``in_proj_*`` tensor; the ``to_out.*`` tensors are renamed. The
    dict is modified in place and also returned.
    """
    # (q marker, k name, v name, fused name); checked before the renames.
    fused_rules = (
        ("to_q.bias", "to_k.bias", "to_v.bias", "in_proj_bias"),
        ("to_q.weight", "to_k.weight", "to_v.weight", "in_proj_weight"),
    )
    rename_rules = (
        ("to_out.bias", "out_proj.bias"),
        ("to_out.weight", "out_proj.weight"),
    )

    for name in tuple(state_dict.keys()):
        if "attention.attn." not in name:
            continue

        for q_tag, k_tag, v_tag, fused_tag in fused_rules:
            if q_tag in name:
                # the k/v entries are located by rewriting the q key
                q_part = state_dict.pop(name)
                k_part = state_dict.pop(name.replace(q_tag, k_tag))
                v_part = state_dict.pop(name.replace(q_tag, v_tag))
                state_dict[name.replace(q_tag, fused_tag)] = torch.cat([q_part, k_part, v_part])
                break
        else:
            for old_tag, new_tag in rename_rules:
                if old_tag in name:
                    state_dict[name.replace(old_tag, new_tag)] = state_dict.pop(name)
                    break

    return state_dict
|
||||||
|
|
||||||
|
|
||||||
def get_sai_model_spec(args, lora=False):
|
def get_sai_model_spec(args, lora=False):
|
||||||
timestamp = time.time()
|
timestamp = time.time()
|
||||||
|
|
||||||
@@ -230,6 +285,8 @@ def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadat
|
|||||||
if save_dtype is not None:
|
if save_dtype is not None:
|
||||||
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
|
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
state_dict = convert_state_dict_normal_attn_to_mha(state_dict)
|
||||||
|
|
||||||
save_file(state_dict, ckpt_file, metadata=sai_metadata)
|
save_file(state_dict, ckpt_file, metadata=sai_metadata)
|
||||||
|
|
||||||
# save text model
|
# save text model
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def main(file):
|
|||||||
|
|
||||||
for key, value in values:
|
for key, value in values:
|
||||||
value = value.to(torch.float32)
|
value = value.to(torch.float32)
|
||||||
logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -40,9 +40,15 @@ def main(args):
|
|||||||
|
|
||||||
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
|
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
|
||||||
generator_c.eval().requires_grad_(False).to(loading_device)
|
generator_c.eval().requires_grad_(False).to(loading_device)
|
||||||
|
# if args.xformers or args.sdpa:
|
||||||
|
print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
|
||||||
|
generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||||
|
|
||||||
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
|
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
|
||||||
generator_b.eval().requires_grad_(False).to(loading_device)
|
generator_b.eval().requires_grad_(False).to(loading_device)
|
||||||
|
# if args.xformers or args.sdpa:
|
||||||
|
print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
|
||||||
|
generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||||
|
|
||||||
# CLIP encoders
|
# CLIP encoders
|
||||||
tokenizer = sc_utils.load_tokenizer(args)
|
tokenizer = sc_utils.load_tokenizer(args)
|
||||||
@@ -332,6 +338,8 @@ if __name__ == "__main__":
|
|||||||
sc_utils.add_text_model_arguments(parser)
|
sc_utils.add_text_model_arguments(parser)
|
||||||
parser.add_argument("--bf16", action="store_true")
|
parser.add_argument("--bf16", action="store_true")
|
||||||
parser.add_argument("--fp16", action="store_true")
|
parser.add_argument("--fp16", action="store_true")
|
||||||
|
parser.add_argument("--xformers", action="store_true")
|
||||||
|
parser.add_argument("--sdpa", action="store_true")
|
||||||
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
|
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
|
||||||
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
|
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -295,10 +295,9 @@ class NetworkTrainer:
|
|||||||
# text_encoder is List[CLIPTextModel] or CLIPTextModel
|
# text_encoder is List[CLIPTextModel] or CLIPTextModel
|
||||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
||||||
|
|
||||||
# # モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
# train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
# train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||||
# if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||||
# vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
|
||||||
|
|
||||||
# 差分追加学習のためにモデルを読み込む
|
# 差分追加学習のためにモデルを読み込む
|
||||||
sys.path.append(os.path.dirname(__file__))
|
sys.path.append(os.path.dirname(__file__))
|
||||||
|
|||||||
@@ -147,6 +147,9 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
previewer = None
|
previewer = None
|
||||||
|
|
||||||
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
|
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
effnet.to(accelerator.device, dtype=effnet_dtype)
|
effnet.to(accelerator.device, dtype=effnet_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user