make separate U-Net for inference

This commit is contained in:
Kohya S
2023-11-26 18:11:30 +09:00
parent fc8649d80f
commit c61e3bf4c9
4 changed files with 366 additions and 82 deletions

View File

@@ -105,7 +105,7 @@ import library.train_util as train_util
from networks.lora import LoRANetwork from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.original_unet import FlashAttentionFunction from library.original_unet import FlashAttentionFunction
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -378,7 +378,7 @@ class PipelineLike:
vae: AutoencoderKL, vae: AutoencoderKL,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: InferUNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
clip_skip: int, clip_skip: int,
clip_model: CLIPModel, clip_model: CLIPModel,
@@ -2196,6 +2196,7 @@ def main(args):
) )
original_unet.load_state_dict(unet.state_dict()) original_unet.load_state_dict(unet.state_dict())
unet = original_unet unet = original_unet
unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)
# VAEを読み込む # VAEを読み込む
if args.vae is not None: if args.vae is not None:
@@ -2352,13 +2353,20 @@ def main(args):
vae = sli_vae vae = sli_vae
del sli_vae del sli_vae
vae.to(dtype).to(device) vae.to(dtype).to(device)
vae.eval()
text_encoder.to(dtype).to(device) text_encoder.to(dtype).to(device)
unet.to(dtype).to(device) unet.to(dtype).to(device)
text_encoder.eval()
unet.eval()
if clip_model is not None: if clip_model is not None:
clip_model.to(dtype).to(device) clip_model.to(dtype).to(device)
clip_model.eval()
if vgg16_model is not None: if vgg16_model is not None:
vgg16_model.to(dtype).to(device) vgg16_model.to(dtype).to(device)
vgg16_model.eval()
# networkを組み込む # networkを組み込む
if args.network_module: if args.network_module:

View File

@@ -1148,10 +1148,6 @@ class UpBlock2D(nn.Module):
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing: if self.training and self.gradient_checkpointing:
@@ -1244,10 +1240,6 @@ class CrossAttnUpBlock2D(nn.Module):
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing: if self.training and self.gradient_checkpointing:
@@ -1444,31 +1436,6 @@ class UNet2DConditionModel(nn.Module):
self.conv_act = nn.SiLU() self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
    """Store or clear the Deep Shrink settings on this object.

    Passing ``ds_depth_1=None`` resets all five ``ds_*`` attributes to
    ``None``. Otherwise the given values are stored, with ``ds_depth_2``
    falling back to ``-1`` and ``ds_timesteps_2`` to ``1000`` when omitted.
    A one-line status message is printed in both cases.
    """
    if ds_depth_1 is None:
        print("Deep Shrink is disabled.")
        self.ds_depth_1 = self.ds_timesteps_1 = None
        self.ds_depth_2 = self.ds_timesteps_2 = None
        self.ds_ratio = None
        return

    # The message reports the arguments as passed, before fallbacks apply.
    print(
        f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
    )
    second_depth = -1 if ds_depth_2 is None else ds_depth_2
    second_timesteps = 1000 if ds_timesteps_2 is None else ds_timesteps_2

    self.ds_depth_1, self.ds_timesteps_1 = ds_depth_1, ds_timesteps_1
    self.ds_depth_2, self.ds_timesteps_2 = second_depth, second_timesteps
    self.ds_ratio = ds_ratio
# region diffusers compatibility # region diffusers compatibility
def prepare_config(self): def prepare_config(self):
self.config = SimpleNamespace() self.config = SimpleNamespace()
@@ -1572,20 +1539,7 @@ class UNet2DConditionModel(nn.Module):
sample = self.conv_in(sample) sample = self.conv_in(sample)
down_block_res_samples = (sample,) down_block_res_samples = (sample,)
for depth, downsample_block in enumerate(self.down_blocks): for downsample_block in self.down_blocks:
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
org_dtype = sample.dtype
if org_dtype == torch.bfloat16:
sample = sample.to(torch.float32)
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
# まあこちらのほうがわかりやすいかもしれない # まあこちらのほうがわかりやすいかもしれない
if downsample_block.has_cross_attention: if downsample_block.has_cross_attention:
@@ -1668,3 +1622,255 @@ class UNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(sample.shape[0]) timesteps = timesteps.expand(sample.shape[0])
return timesteps return timesteps
class InferUNet2DConditionModel:
    """Inference-time wrapper around ``UNet2DConditionModel`` that adds Deep Shrink.

    This is a plain object (it has no base class). It keeps the wrapped model
    in ``self.delegate``, forwards unknown attribute lookups to it, and
    replaces the ``forward`` of the delegate and of its up blocks with the
    Deep Shrink aware implementations defined here.
    """

    def __init__(self, original_unet: "UNet2DConditionModel"):
        """Wrap ``original_unet`` and patch its forward methods in place.

        Args:
            original_unet: model to wrap. Its ``forward`` attribute and the
                ``forward`` attribute of each ``UpBlock2D`` /
                ``CrossAttnUpBlock2D`` in ``up_blocks`` are overwritten.
        """
        self.delegate = original_unet

        # Calling this wrapper calls the delegate (see __call__), which in
        # turn dispatches to the delegate's own `forward`. Installing our
        # forward on the delegate is what routes calls through Deep Shrink;
        # overriding __call__ alone would not be enough.
        original_unet.forward = self.forward

        def bind_block(func, block):
            # Pre-bind `block` as the `_self` argument of one of our
            # replacement forwards.
            def forward(*args, **kwargs):
                return func(block, *args, **kwargs)

            return forward

        # Replace the up blocks' forwards with the variants below. Those
        # variants have no gradient-checkpointing branch.
        for up_block in original_unet.up_blocks:
            block_type = up_block.__class__.__name__
            if block_type == "UpBlock2D":
                up_block.forward = bind_block(self.up_block_forward, up_block)
            elif block_type == "CrossAttnUpBlock2D":
                up_block.forward = bind_block(self.cross_attn_up_block_forward, up_block)

        # Deep Shrink (disabled until set_deep_shrink is called)
        self.ds_depth_1 = None
        self.ds_depth_2 = None
        self.ds_timesteps_1 = None
        self.ds_timesteps_2 = None
        self.ds_ratio = None

    def __getattr__(self, name):
        """Forward attributes not found on the wrapper to the wrapped model.

        Raises:
            AttributeError: if ``delegate`` itself is not set yet, or if the
                wrapped model does not have the attribute either.
        """
        # __getattr__ only runs after normal lookup failed. If the missing
        # name is `delegate` (instance created without __init__, e.g. by
        # copy or pickle), reading self.delegate below would re-enter this
        # method without end, so fail with a regular AttributeError instead.
        if name == "delegate":
            raise AttributeError(name)
        return getattr(self.delegate, name)

    def __call__(self, *args, **kwargs):
        """Call the wrapped model with the given arguments."""
        return self.delegate(*args, **kwargs)

    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
        """Enable, reconfigure or disable Deep Shrink.

        Args:
            ds_depth_1: down-block index at which the sample is shrunk while
                the timestep is >= ``ds_timesteps_1``; ``None`` disables
                Deep Shrink and clears every setting.
            ds_timesteps_1: timestep threshold for ``ds_depth_1``.
            ds_depth_2: optional second down-block index, used while the
                timestep is in ``[ds_timesteps_2, ds_timesteps_1)``. Stored
                as ``-1`` when omitted, which never equals a block index.
            ds_timesteps_2: lower threshold for ``ds_depth_2``; stored as
                ``1000`` when omitted.
            ds_ratio: ``scale_factor`` passed to the interpolation.
        """
        if ds_depth_1 is None:
            print("Deep Shrink is disabled.")
            self.ds_depth_1 = None
            self.ds_timesteps_1 = None
            self.ds_depth_2 = None
            self.ds_timesteps_2 = None
            self.ds_ratio = None
        else:
            print(
                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
            )
            self.ds_depth_1 = ds_depth_1
            self.ds_timesteps_1 = ds_timesteps_1
            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
            self.ds_ratio = ds_ratio

    def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """Replacement forward for an ``UpBlock2D``.

        Args:
            _self: the up block being run (bound in ``__init__``).
            hidden_states: current feature map.
            res_hidden_states_tuple: skip connections, consumed from the end.
            temb: time embedding passed to every resnet.
            upsample_size: passed to every upsampler.

        Returns:
            The feature map after all resnets and upsamplers.
        """
        for resnet in _self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # Deep Shrink: bring the feature map back to the skip
            # connection's spatial size before concatenating.
            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
                hidden_states = resize_like(hidden_states, res_hidden_states)

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)

        if _self.upsamplers is not None:
            for upsampler in _self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

    def cross_attn_up_block_forward(
        self,
        _self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
    ):
        """Replacement forward for a ``CrossAttnUpBlock2D``.

        Same as ``up_block_forward``, but each resnet is followed by its
        paired attention layer, which receives ``encoder_hidden_states`` and
        whose ``.sample`` output becomes the new feature map.
        """
        for resnet, attn in zip(_self.resnets, _self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # Deep Shrink: bring the feature map back to the skip
            # connection's spatial size before concatenating.
            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
                hidden_states = resize_like(hidden_states, res_hidden_states)

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

        if _self.upsamplers is not None:
            for upsampler in _self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
    ) -> Union[Dict, Tuple]:
        r"""Run the wrapped U-Net, shrinking the sample when Deep Shrink applies.

        The current implementation is a copy of `UNet2DConditionModel.forward()`
        with Deep Shrink added.

        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            class_labels: accepted but not used by this implementation.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a dict instead of a plain tuple.
            down_block_additional_residuals: optional tensors added in place to
                the skip connections, one per skip connection (ControlNet output).
            mid_block_additional_residual: optional tensor added in place to
                the mid block output (ControlNet output).

        Returns:
            `SampleOutput` or `tuple`:
            `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        _self = self.delegate

        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary. The original author notes that quality probably
        # suffers in that case, so sizes divisible by the factor are preferable.
        default_overall_up_factor = 2**_self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            # logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # 1. time
        timesteps = timestep
        timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # only does work for unusual inputs

        t_emb = _self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this (e.g. casting inside time_proj).
        t_emb = t_emb.to(dtype=_self.dtype)
        emb = _self.time_embedding(t_emb)

        # 2. pre-process
        sample = _self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for depth, downsample_block in enumerate(_self.down_blocks):
            # Deep Shrink: shrink the sample right before the configured block,
            # depending on which timestep window the first timestep falls into.
            if self.ds_depth_1 is not None:
                if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
                    self.ds_depth_2 is not None
                    and depth == self.ds_depth_2
                    and timesteps[0] < self.ds_timesteps_1
                    and timesteps[0] >= self.ds_timesteps_2
                ):
                    org_dtype = sample.dtype
                    # bfloat16 is upcast to float32 for the interpolation and cast
                    # back afterwards (presumably bicubic lacks bf16 support).
                    if org_dtype == torch.bfloat16:
                        sample = sample.to(torch.float32)
                    sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)

            # Down blocks could always take encoder_hidden_states in forward,
            # but branching here is arguably easier to follow.
            if downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # add the ControlNet output to the skip connections
        if down_block_additional_residuals is not None:
            down_block_res_samples = list(down_block_res_samples)
            for i in range(len(down_block_res_samples)):
                down_block_res_samples[i] += down_block_additional_residuals[i]
            down_block_res_samples = tuple(down_block_res_samples)

        # 4. mid
        sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

        # add the ControlNet output
        if mid_block_additional_residual is not None:
            sample += mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(_self.up_blocks):
            is_final_block = i == len(_self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]  # skip connection

            # if we have not reached the final block and need to forward the upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        sample = _self.conv_norm_out(sample)
        sample = _self.conv_act(sample)
        sample = _self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return SampleOutput(sample=sample)

View File

@@ -24,7 +24,7 @@
import math import math
from types import SimpleNamespace from types import SimpleNamespace
from typing import Optional from typing import Any, Optional
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
@@ -1013,31 +1013,6 @@ class SdxlUNet2DConditionModel(nn.Module):
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)] [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
) )
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
    """Record the Deep Shrink configuration, or clear it.

    When ``ds_depth_1`` is ``None`` every ``ds_*`` attribute is reset to
    ``None``. Otherwise the arguments are stored as given, except that a
    missing ``ds_depth_2`` is stored as ``-1`` and a missing
    ``ds_timesteps_2`` as ``1000``. A status line is printed either way.
    """
    enabled = ds_depth_1 is not None
    if not enabled:
        print("Deep Shrink is disabled.")
        settings = (None, None, None, None, None)
    else:
        # The message shows the raw arguments, not the stored fallbacks.
        print(
            f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
        )
        settings = (
            ds_depth_1,
            ds_timesteps_1,
            ds_depth_2 if ds_depth_2 is not None else -1,
            ds_timesteps_2 if ds_timesteps_2 is not None else 1000,
            ds_ratio,
        )

    (
        self.ds_depth_1,
        self.ds_timesteps_1,
        self.ds_depth_2,
        self.ds_timesteps_2,
        self.ds_ratio,
    ) = settings
# region diffusers compatibility # region diffusers compatibility
def prepare_config(self): def prepare_config(self):
self.config = SimpleNamespace() self.config = SimpleNamespace()
@@ -1120,7 +1095,97 @@ class SdxlUNet2DConditionModel(nn.Module):
# h = x.type(self.dtype) # h = x.type(self.dtype)
h = x h = x
for depth, module in enumerate(self.input_blocks): for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
return h
class InferSdxlUNet2DConditionModel:
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
self.delegate = original_unet
# override original model's forward method: because forward is not called by `__call__`
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
self.delegate.forward = self.forward
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
    """Enable, reconfigure or disable Deep Shrink on this wrapper.

    ``ds_depth_1=None`` clears all five ``ds_*`` attributes. Any other value
    stores the arguments, substituting ``-1`` for a missing ``ds_depth_2``
    and ``1000`` for a missing ``ds_timesteps_2``. A status line is printed
    in both cases.
    """
    if ds_depth_1 is None:
        print("Deep Shrink is disabled.")
        self.ds_depth_1 = self.ds_timesteps_1 = None
        self.ds_depth_2 = self.ds_timesteps_2 = None
        self.ds_ratio = None
        return

    # The message reports the arguments as passed, before fallbacks apply.
    print(
        f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
    )
    fallback_depth = -1 if ds_depth_2 is None else ds_depth_2
    fallback_timesteps = 1000 if ds_timesteps_2 is None else ds_timesteps_2

    self.ds_depth_1, self.ds_timesteps_1 = ds_depth_1, ds_timesteps_1
    self.ds_depth_2, self.ds_timesteps_2 = fallback_depth, fallback_timesteps
    self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
"""
_self = self.delegate
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
# assert x.dtype == _self.dtype
emb = emb + _self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
if isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
# h = x.type(self.dtype)
h = x
for depth, module in enumerate(_self.input_blocks):
# Deep Shrink # Deep Shrink
if self.ds_depth_1 is not None: if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
@@ -1138,9 +1203,9 @@ class SdxlUNet2DConditionModel(nn.Module):
h = call_module(module, h, emb, context) h = call_module(module, h, emb, context)
hs.append(h) hs.append(h)
h = call_module(self.middle_block, h, emb, context) h = call_module(_self.middle_block, h, emb, context)
for module in self.output_blocks: for module in _self.output_blocks:
# Deep Shrink # Deep Shrink
if self.ds_depth_1 is not None: if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]: if hs[-1].shape[-2:] != h.shape[-2:]:
@@ -1156,7 +1221,7 @@ class SdxlUNet2DConditionModel(nn.Module):
h = resize_like(h, x) h = resize_like(h, x)
h = h.type(x.dtype) h = h.type(x.dtype)
h = call_module(self.out, h, emb, context) h = call_module(_self.out, h, emb, context)
return h return h

View File

@@ -57,7 +57,7 @@ import library.train_util as train_util
import library.sdxl_model_util as sdxl_model_util import library.sdxl_model_util as sdxl_model_util
import library.sdxl_train_util as sdxl_train_util import library.sdxl_train_util as sdxl_train_util
from networks.lora import LoRANetwork from networks.lora import LoRANetwork
from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.original_unet import FlashAttentionFunction from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite from networks.control_net_lllite import ControlNetLLLite
@@ -290,7 +290,7 @@ class PipelineLike:
vae: AutoencoderKL, vae: AutoencoderKL,
text_encoders: List[CLIPTextModel], text_encoders: List[CLIPTextModel],
tokenizers: List[CLIPTokenizer], tokenizers: List[CLIPTokenizer],
unet: SdxlUNet2DConditionModel, unet: InferSdxlUNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
clip_skip: int, clip_skip: int,
): ):
@@ -328,7 +328,7 @@ class PipelineLike:
self.vae = vae self.vae = vae
self.text_encoders = text_encoders self.text_encoders = text_encoders
self.tokenizers = tokenizers self.tokenizers = tokenizers
self.unet: SdxlUNet2DConditionModel = unet self.unet: InferSdxlUNet2DConditionModel = unet
self.scheduler = scheduler self.scheduler = scheduler
self.safety_checker = None self.safety_checker = None
@@ -1371,6 +1371,7 @@ def main(args):
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
) )
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
# xformers、Hypernetwork対応 # xformers、Hypernetwork対応
if not args.diffusers_xformers: if not args.diffusers_xformers:
@@ -1526,10 +1527,14 @@ def main(args):
print("set vae_dtype to float32") print("set vae_dtype to float32")
vae_dtype = torch.float32 vae_dtype = torch.float32
vae.to(vae_dtype).to(device) vae.to(vae_dtype).to(device)
vae.eval()
text_encoder1.to(dtype).to(device) text_encoder1.to(dtype).to(device)
text_encoder2.to(dtype).to(device) text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device) unet.to(dtype).to(device)
text_encoder1.eval()
text_encoder2.eval()
unet.eval()
# networkを組み込む # networkを組み込む
if args.network_module: if args.network_module: