fix sdxl timestep embedding

This commit is contained in:
Kohya S
2024-03-15 21:05:00 +09:00
parent 2d7389185c
commit f811b115ba
2 changed files with 15 additions and 3 deletions

View File

@@ -31,8 +31,10 @@ from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
IN_CHANNELS: int = 4
@@ -1074,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
@@ -1132,7 +1134,7 @@ class InferSdxlUNet2DConditionModel:
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
@@ -1164,7 +1166,7 @@ class InferSdxlUNet2DConditionModel:
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)