mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
rename and update
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
from typing import Optional, List, Type
|
||||
import torch
|
||||
from networks.lora import LoRAModule, LoRANetwork
|
||||
from library import sdxl_original_unet
|
||||
|
||||
|
||||
@@ -21,6 +20,9 @@ TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not
|
||||
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
|
||||
ATTN1_2_ONLY = True
|
||||
|
||||
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
|
||||
ATTN_QKV_ONLY = True
|
||||
|
||||
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
|
||||
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
|
||||
ATTN1_ETC_ONLY = False # True
|
||||
@@ -30,126 +32,159 @@ ATTN1_ETC_ONLY = False # True
|
||||
TRANSFORMER_MAX_BLOCK_INDEX = None
|
||||
|
||||
|
||||
class LoRAModuleControlNet(LoRAModule):
|
||||
def __init__(self, depth, cond_emb_dim, name, org_module, multiplier, lora_dim, alpha, dropout=None):
|
||||
super().__init__(name, org_module, multiplier, lora_dim, alpha, dropout=dropout)
|
||||
class LLLiteModule(torch.nn.Module):
|
||||
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None):
|
||||
super().__init__()
|
||||
|
||||
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
self.lllite_name = name
|
||||
self.cond_emb_dim = cond_emb_dim
|
||||
|
||||
# conditioning1は、conditioning image embeddingを、各LoRA的モジュールでさらに学習する。ここはtimestepごとに呼ばれない
|
||||
# それぞれのモジュールで異なる表現を学習することを期待している
|
||||
# conditioning1 learns conditioning image embedding in each LoRA-like module. this is not called for each timestep
|
||||
# we expect to learn different representations in each module
|
||||
|
||||
# conditioning2は、conditioning1の出力とLoRAの出力を結合し、LoRAの出力に加算する。timestepごとに呼ばれる
|
||||
# conditioning image embeddingとU-Netの出力を合わせて学ぶことで、conditioningに応じたU-Netの調整を行う
|
||||
# conditioning2 combines the output of conditioning1 and the output of LoRA, and adds it to the output of LoRA. this is called for each timestep
|
||||
# by learning the output of conditioning image embedding and U-Net together, we adjust U-Net according to conditioning
|
||||
self.org_module = [org_module]
|
||||
self.dropout = dropout
|
||||
|
||||
if self.is_conv2d:
|
||||
self.conditioning1 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0),
|
||||
in_dim = org_module.in_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
|
||||
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
|
||||
# conditioning1 embeds conditioning image. it is not called for each timestep
|
||||
modules = []
|
||||
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
|
||||
if depth == 1:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
elif depth == 2:
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
||||
elif depth == 3:
|
||||
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
||||
modules.append(torch.nn.ReLU(inplace=True))
|
||||
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
||||
|
||||
self.conditioning1 = torch.nn.Sequential(*modules)
|
||||
|
||||
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
|
||||
# midでconditioning image embeddingと入力を結合する
|
||||
# upで元の次元数に戻す
|
||||
# これらはtimestepごとに呼ばれる
|
||||
# reduce the number of input dimensions with down. inspired by LoRA
|
||||
# combine conditioning image embedding and input with mid
|
||||
# restore to the original dimension with up
|
||||
# these are called for each timestep
|
||||
|
||||
if self.is_conv2d:
|
||||
self.down = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.conditioning2 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(lora_dim + cond_emb_dim, cond_emb_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(cond_emb_dim, lora_dim, kernel_size=1, stride=1, padding=0),
|
||||
self.mid = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.up = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
|
||||
)
|
||||
else:
|
||||
self.conditioning1 = torch.nn.Sequential(
|
||||
torch.nn.Linear(cond_emb_dim, cond_emb_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Linear(cond_emb_dim, cond_emb_dim),
|
||||
# midの前にconditioningをreshapeすること / reshape conditioning before mid
|
||||
self.down = torch.nn.Sequential(
|
||||
torch.nn.Linear(in_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.conditioning2 = torch.nn.Sequential(
|
||||
torch.nn.Linear(lora_dim + cond_emb_dim, cond_emb_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Linear(cond_emb_dim, lora_dim),
|
||||
self.mid = torch.nn.Sequential(
|
||||
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.up = torch.nn.Sequential(
|
||||
torch.nn.Linear(mlp_dim, in_dim),
|
||||
)
|
||||
|
||||
# Zero-Convにするならコメントを外す / uncomment if you want to use Zero-Conv
|
||||
# torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv
|
||||
# Zero-Convにする / set to Zero-Conv
|
||||
torch.nn.init.zeros_(self.up[0].weight) # zero conv
|
||||
|
||||
self.depth = depth # 1~3
|
||||
self.cond_emb = None
|
||||
self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
|
||||
self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
|
||||
|
||||
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
|
||||
# batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
|
||||
# Controlの種類によっては使えるかも
|
||||
# both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
|
||||
# it may be available depending on the type of Control
|
||||
|
||||
def set_cond_image(self, cond_image):
|
||||
r"""
|
||||
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||
"""
|
||||
# conv2dとlinearでshapeが違うので必要な方を選択 / select the required one because the shape is different for conv2d and linear
|
||||
cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d
|
||||
|
||||
cond_emb = cond_embs[self.depth - 1]
|
||||
|
||||
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
||||
self.cond_emb = self.conditioning1(cond_emb)
|
||||
cx = self.conditioning1(cond_image)
|
||||
if not self.is_conv2d:
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = cx.shape
|
||||
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
||||
self.cond_emb = cx
|
||||
|
||||
def set_batch_cond_only(self, cond_only, zeros):
|
||||
self.batch_cond_only = cond_only
|
||||
self.use_zeros_for_batch_uncond = zeros
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module[0].forward
|
||||
self.org_module[0].forward = self.forward
|
||||
|
||||
def forward(self, x):
|
||||
if self.cond_emb is None:
|
||||
return self.org_forward(x)
|
||||
|
||||
# LoRA-Down
|
||||
lx = x
|
||||
if self.batch_cond_only:
|
||||
lx = lx[1::2] # cond only in inference
|
||||
|
||||
lx = self.lora_down(lx)
|
||||
|
||||
if self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||
|
||||
# conditioning image embeddingを結合 / combine conditioning image embedding
|
||||
r"""
|
||||
学習用の便利forward。元のモジュールのforwardを呼び出す
|
||||
/ convenient forward for training. call the forward of the original module
|
||||
"""
|
||||
cx = self.cond_emb
|
||||
|
||||
if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only
|
||||
if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
|
||||
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
||||
if self.use_zeros_for_batch_uncond:
|
||||
cx[0::2] = 0.0 # uncond is zero
|
||||
# print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
|
||||
# print(f"C {self.lllite_name}, lx.shape={lx.shape}, cx.shape={cx.shape}")
|
||||
|
||||
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
||||
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
||||
# down reduces the number of input dimensions and combines it with conditioning image embedding
|
||||
# we expect that it will mix well by combining in the channel direction instead of adding
|
||||
cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.conditioning2(cx)
|
||||
|
||||
lx = lx + cx # lxはresidual的に加算される / lx is added residually
|
||||
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.mid(cx)
|
||||
|
||||
# LoRA-Up
|
||||
lx = self.lora_up(lx)
|
||||
if self.dropout is not None and self.training:
|
||||
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
||||
|
||||
# call original module
|
||||
x = self.org_forward(x)
|
||||
cx = self.up(cx)
|
||||
|
||||
# add LoRA
|
||||
# residualを加算する / add residual
|
||||
if self.batch_cond_only:
|
||||
x[1::2] += lx * self.multiplier * self.scale
|
||||
x[1::2] += cx
|
||||
else:
|
||||
x += lx * self.multiplier * self.scale
|
||||
# to_outを対象とすると、cloneがないと次のエラーが出る / if to_out is the target, the following error will occur without clone
|
||||
# RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace.
|
||||
# This view was created inside a custom Function ...
|
||||
# x = x.clone()
|
||||
|
||||
x += cx
|
||||
|
||||
x = self.org_forward(x) # ここで元のモジュールを呼び出す / call the original module here
|
||||
return x
|
||||
|
||||
|
||||
class LoRAControlNet(torch.nn.Module):
|
||||
class ControlNetLLLite(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: sdxl_original_unet.SdxlUNet2DConditionModel,
|
||||
cond_emb_dim: int = 16,
|
||||
lora_dim: int = 16,
|
||||
alpha: float = 1,
|
||||
mlp_dim: int = 16,
|
||||
dropout: Optional[float] = None,
|
||||
varbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
@@ -161,9 +196,9 @@ class LoRAControlNet(torch.nn.Module):
|
||||
target_replace_modules: List[torch.nn.Module],
|
||||
module_class: Type[object],
|
||||
) -> List[torch.nn.Module]:
|
||||
prefix = LoRANetwork.LORA_PREFIX_UNET
|
||||
prefix = "lllite_unet"
|
||||
|
||||
loras = []
|
||||
modules = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
@@ -190,13 +225,13 @@ class LoRAControlNet(torch.nn.Module):
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
lllite_name = prefix + "." + name + "." + child_name
|
||||
lllite_name = lllite_name.replace(".", "_")
|
||||
|
||||
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
|
||||
p = lora_name.find("transformer_blocks")
|
||||
p = lllite_name.find("transformer_blocks")
|
||||
if p >= 0:
|
||||
tf_index = int(lora_name[p:].split("_")[2])
|
||||
tf_index = int(lllite_name[p:].split("_")[2])
|
||||
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
||||
continue
|
||||
|
||||
@@ -204,104 +239,63 @@ class LoRAControlNet(torch.nn.Module):
|
||||
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
|
||||
# time emb is not applied
|
||||
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
|
||||
if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)):
|
||||
if "emb_layers" in lllite_name or (
|
||||
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
|
||||
):
|
||||
continue
|
||||
|
||||
if ATTN1_2_ONLY:
|
||||
if not ("attn1" in lora_name or "attn2" in lora_name):
|
||||
if not ("attn1" in lllite_name or "attn2" in lllite_name):
|
||||
continue
|
||||
if ATTN_QKV_ONLY:
|
||||
if "to_out" in lllite_name:
|
||||
continue
|
||||
|
||||
if ATTN1_ETC_ONLY:
|
||||
if "proj_out" in lora_name:
|
||||
if "proj_out" in lllite_name:
|
||||
pass
|
||||
elif "attn1" in lora_name and ("to_k" in lora_name or "to_v" in lora_name or "to_out" in lora_name):
|
||||
elif "attn1" in lllite_name and (
|
||||
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
|
||||
):
|
||||
pass
|
||||
elif "ff_net_2" in lora_name:
|
||||
elif "ff_net_2" in lllite_name:
|
||||
pass
|
||||
else:
|
||||
continue
|
||||
|
||||
lora = module_class(
|
||||
module = module_class(
|
||||
depth,
|
||||
cond_emb_dim,
|
||||
lora_name,
|
||||
lllite_name,
|
||||
child_module,
|
||||
1.0,
|
||||
lora_dim,
|
||||
alpha,
|
||||
mlp_dim,
|
||||
dropout=dropout,
|
||||
)
|
||||
loras.append(lora)
|
||||
return loras
|
||||
modules.append(module)
|
||||
return modules
|
||||
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
|
||||
if not TRANSFORMER_ONLY:
|
||||
target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
# create module instances
|
||||
self.unet_loras: List[LoRAModuleControlNet] = create_modules(unet, target_modules, LoRAModuleControlNet)
|
||||
print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
# conditioning image embedding
|
||||
|
||||
# control画像そのままではLoRA的モジュールの入力にはサイズもチャネルも扱いにくいので、
|
||||
# 適切な潜在空間に変換する。ここでは、conditioning image embeddingと呼ぶ
|
||||
# ただcontrol画像自体にはあまり情報量はないので、conditioning image embeddingはわりと小さくてよいはず
|
||||
# また、conditioning image embeddingは、各LoRA的モジュールでさらに個別に学習する
|
||||
# depthに応じて3つのサイズを用意する
|
||||
|
||||
# conditioning image embedding is converted to an appropriate latent space
|
||||
# because the size and channels of the input to the LoRA-like module are difficult to handle
|
||||
# we call it conditioning image embedding
|
||||
# however, the control image itself does not have much information, so the conditioning image embedding should be small
|
||||
# conditioning image embedding is also learned individually in each LoRA-like module
|
||||
# prepare three sizes according to depth
|
||||
|
||||
self.cond_block0 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent (from VAE) size
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.cond_block1 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.cond_block2 = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.ReLU(inplace=True),
|
||||
)
|
||||
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
|
||||
print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.cond_block0(x)
|
||||
x0 = x
|
||||
x = self.cond_block1(x)
|
||||
x1 = x
|
||||
x = self.cond_block2(x)
|
||||
x2 = x
|
||||
return x # dummy
|
||||
|
||||
x_3d = [] # for Linear
|
||||
for x0 in [x0, x1, x2]:
|
||||
# b,c,h,w -> b,h*w,c
|
||||
n, c, h, w = x0.shape
|
||||
x0 = x0.view(n, c, h * w).permute(0, 2, 1)
|
||||
x_3d.append(x0)
|
||||
|
||||
return [x0, x1, x2], x_3d
|
||||
|
||||
def set_cond_embs(self, cond_embs_4d, cond_embs_3d):
|
||||
def set_cond_image(self, cond_image):
|
||||
r"""
|
||||
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
||||
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
||||
"""
|
||||
for lora in self.unet_loras:
|
||||
lora.set_cond_embs(cond_embs_4d, cond_embs_3d)
|
||||
for module in self.unet_modules:
|
||||
module.set_cond_image(cond_image)
|
||||
|
||||
def set_batch_cond_only(self, cond_only, zeros):
|
||||
for lora in self.unet_loras:
|
||||
lora.set_batch_cond_only(cond_only, zeros)
|
||||
for module in self.unet_modules:
|
||||
module.set_batch_cond_only(cond_only, zeros)
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
@@ -315,10 +309,10 @@ class LoRAControlNet(torch.nn.Module):
|
||||
return info
|
||||
|
||||
def apply_to(self):
|
||||
print("applying LoRA for U-Net...")
|
||||
for lora in self.unet_loras:
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
print("applying LLLite for U-Net...")
|
||||
for module in self.unet_modules:
|
||||
module.apply_to()
|
||||
self.add_module(module.lllite_name, module)
|
||||
|
||||
# マージできるかどうかを返す
|
||||
def is_mergeable(self):
|
||||
@@ -367,16 +361,15 @@ class LoRAControlNet(torch.nn.Module):
|
||||
if __name__ == "__main__":
|
||||
# デバッグ用 / for debug
|
||||
|
||||
# これを指定しないとエラーが出てcond_blockが学習できない / if not specified, an error occurs and cond_block cannot be learned
|
||||
sdxl_original_unet.USE_REENTRANT = False
|
||||
# sdxl_original_unet.USE_REENTRANT = False
|
||||
|
||||
# test shape etc
|
||||
print("create unet")
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
unet.to("cuda").to(torch.float16)
|
||||
|
||||
print("create LoRA controlnet")
|
||||
control_net = LoRAControlNet(unet, 64, 32, 1)
|
||||
print("create ControlNet-LLLite")
|
||||
control_net = ControlNetLLLite(unet, 32, 64)
|
||||
control_net.apply_to()
|
||||
control_net.to("cuda")
|
||||
|
||||
@@ -414,6 +407,7 @@ if __name__ == "__main__":
|
||||
print("start training")
|
||||
steps = 10
|
||||
|
||||
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
|
||||
for step in range(steps):
|
||||
print(f"step {step}")
|
||||
|
||||
@@ -425,8 +419,7 @@ if __name__ == "__main__":
|
||||
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
cond_embs_4d, cond_embs_3d = control_net(conditioning_image)
|
||||
control_net.set_cond_embs(cond_embs_4d, cond_embs_3d)
|
||||
control_net.set_cond_image(conditioning_image)
|
||||
|
||||
output = unet(x, t, ctx, y)
|
||||
target = torch.randn_like(output)
|
||||
@@ -436,3 +429,8 @@ if __name__ == "__main__":
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
print(sample_param)
|
||||
|
||||
# from safetensors.torch import save_file
|
||||
|
||||
# save_file(control_net.state_dict(), "logs/control_net.safetensors")
|
||||
Reference in New Issue
Block a user