implement FreeU

This commit is contained in:
Kohya S
2023-09-22 07:56:09 +09:00
parent db7a28ac25
commit 40525d4f4b
2 changed files with 84 additions and 1 deletions

View File

@@ -996,6 +996,76 @@ class SdxlUNet2DConditionModel(nn.Module):
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
)
# FreeU
self.freeU = False
self.freeUSl = 0.5
self.freeURThres = 0.5
self.freeUBl = 0.5
# implementation of FreeU
# FreeU: Free Lunch in Diffusion U-Net https://arxiv.org/abs/2309.11497
def set_free_u_enabled(self, enabled: bool, bl=0.5, sl=0.5, rthresh=0.5):
print(f"FreeU: {enabled}, bl={bl}, sl={sl}, rthresh={rthresh}")
self.freeU = enabled
self.freeUSl = sl
self.freeURThres = rthresh
self.freeUBl = bl
def spectral_modulation(self, skip_feature, sl=0.5, rthresh=0.5):
"""
スキップ特徴を周波数領域で修正する関数
:param skip_feature: スキップ特徴のテンソル [b, c, H, W]
:param sl: スケーリング係数
:param rthresh: 周波数の閾値
:return: 修正されたスキップ特徴
"""
import torch.fft
org_dtype = skip_feature.dtype
if org_dtype == torch.bfloat16:
skip_feature = skip_feature.to(torch.float32)
# FFTを計算
F = torch.fft.fftn(skip_feature, dim=(2, 3))
# 周波数領域での座標を計算
freq_x = torch.fft.fftfreq(skip_feature.size(2), d=1 / skip_feature.size(2)).to(skip_feature.device)
freq_y = torch.fft.fftfreq(skip_feature.size(3), d=1 / skip_feature.size(3)).to(skip_feature.device)
# 2Dグリッドを作成
freq_x = freq_x[:, None] # [H, 1]
freq_y = freq_y[None, :] # [1, W]
# ラジアス(距離)を計算
r = torch.sqrt(freq_x**2 + freq_y**2)
# 32,32: tensor(0., device='cuda:0') tensor(22.6274, device='cuda:0') tensor(12.2521, device='cuda:0')
# 64,64: tensor(0., device='cuda:0') tensor(45.2548, device='cuda:0') tensor(24.4908, device='cuda:0')
# 128,128: tensor(0., device='cuda:0') tensor(90.5097, device='cuda:0') tensor(48.9748, device='cuda:0')
# マスクを作成
mask = torch.ones_like(r)
mask[r < rthresh] = sl
# b,c,H,Wの形状にブロードキャスト
# TODO shapeごとに同じなのでキャッシュすると良さそう
mask = mask[None, None, :, :]
# 周波数領域での要素ごとの乗算
F_prime = F * mask
# 逆FFTを計算
modified_skip_feature = torch.fft.ifftn(F_prime, dim=(2, 3))
modified_skip_feature = modified_skip_feature.real # 実部のみを取得
if org_dtype == torch.bfloat16:
modified_skip_feature = modified_skip_feature.to(org_dtype)
return modified_skip_feature
# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
@@ -1079,11 +1149,20 @@ class SdxlUNet2DConditionModel(nn.Module):
h = x
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
if self.freeU:
h_mod = self.spectral_modulation(h, self.freeUSl, self.freeURThres)
hs.append(h_mod)
else:
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
for module in self.output_blocks:
if self.freeU:
ch = h.shape[1]
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUBl
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)

View File

@@ -1521,6 +1521,10 @@ def main(args):
text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device)
# freeU
# unet.set_free_u_enabled(False, 1.0, 1.0, 0)
unet.set_free_u_enabled(True, 1.4, 1.0, 10)
# networkを組み込む
if args.network_module:
networks = []