Merge branch 'dev' into gradual_latent_hires_fix

This commit is contained in:
Kohya S
2023-11-26 18:12:24 +09:00
15 changed files with 457 additions and 141 deletions

View File

@@ -355,7 +355,7 @@ def train(args):
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred: if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss: if args.debiased_estimation_loss:

View File

@@ -105,7 +105,7 @@ import library.train_util as train_util
from networks.lora import LoRANetwork from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.original_unet import FlashAttentionFunction from library.original_unet import FlashAttentionFunction
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -378,7 +378,7 @@ class PipelineLike:
vae: AutoencoderKL, vae: AutoencoderKL,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: InferUNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
clip_skip: int, clip_skip: int,
clip_model: CLIPModel, clip_model: CLIPModel,
@@ -2365,6 +2365,7 @@ def main(args):
) )
original_unet.load_state_dict(unet.state_dict()) original_unet.load_state_dict(unet.state_dict())
unet = original_unet unet = original_unet
unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)
# VAEを読み込む # VAEを読み込む
if args.vae is not None: if args.vae is not None:
@@ -2521,13 +2522,20 @@ def main(args):
vae = sli_vae vae = sli_vae
del sli_vae del sli_vae
vae.to(dtype).to(device) vae.to(dtype).to(device)
vae.eval()
text_encoder.to(dtype).to(device) text_encoder.to(dtype).to(device)
unet.to(dtype).to(device) unet.to(dtype).to(device)
text_encoder.eval()
unet.eval()
if clip_model is not None: if clip_model is not None:
clip_model.to(dtype).to(device) clip_model.to(dtype).to(device)
clip_model.eval()
if vgg16_model is not None: if vgg16_model is not None:
vgg16_model.to(dtype).to(device) vgg16_model.to(dtype).to(device)
vgg16_model.eval()
# networkを組み込む # networkを組み込む
if args.network_module: if args.network_module:

View File

@@ -57,10 +57,13 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
noise_scheduler.alphas_cumprod = alphas_cumprod noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    """Apply Min-SNR-gamma weighting (arXiv:2303.09556) to a per-sample loss.

    Args:
        loss: per-sample loss tensor, shape (batch,).
        timesteps: diffusion timesteps for each sample (iterable of ints or
            0-dim tensors, used to index ``noise_scheduler.all_snr``).
        noise_scheduler: scheduler exposing a precomputed ``all_snr`` table
            indexable by timestep.
        gamma: the min-SNR clamp value (the gamma hyperparameter).
        v_prediction: if True, use the v-prediction form
            ``min(SNR, gamma) / (SNR + 1)``; otherwise the epsilon-prediction
            form ``min(SNR, gamma) / SNR``.

    Returns:
        The loss tensor scaled element-wise by the per-sample SNR weight.
    """
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    # clamp SNR at gamma: min(SNR, gamma)
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss

View File

@@ -1148,10 +1148,6 @@ class UpBlock2D(nn.Module):
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing: if self.training and self.gradient_checkpointing:
@@ -1244,10 +1240,6 @@ class CrossAttnUpBlock2D(nn.Module):
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing: if self.training and self.gradient_checkpointing:
@@ -1444,31 +1436,6 @@ class UNet2DConditionModel(nn.Module):
self.conv_act = nn.SiLU() self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
# region diffusers compatibility # region diffusers compatibility
def prepare_config(self): def prepare_config(self):
self.config = SimpleNamespace() self.config = SimpleNamespace()
@@ -1572,20 +1539,7 @@ class UNet2DConditionModel(nn.Module):
sample = self.conv_in(sample) sample = self.conv_in(sample)
down_block_res_samples = (sample,) down_block_res_samples = (sample,)
for depth, downsample_block in enumerate(self.down_blocks): for downsample_block in self.down_blocks:
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
org_dtype = sample.dtype
if org_dtype == torch.bfloat16:
sample = sample.to(torch.float32)
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
# まあこちらのほうがわかりやすいかもしれない # まあこちらのほうがわかりやすいかもしれない
if downsample_block.has_cross_attention: if downsample_block.has_cross_attention:
@@ -1668,3 +1622,255 @@ class UNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(sample.shape[0]) timesteps = timesteps.expand(sample.shape[0])
return timesteps return timesteps
class InferUNet2DConditionModel:
    """Inference-time wrapper around ``UNet2DConditionModel`` adding Deep Shrink.

    The wrapped (original) U-Net is kept untouched as ``self.delegate``; this
    class monkey-patches the delegate's ``forward`` and its up-blocks' ``forward``
    so hidden states can be rescaled mid-network (Deep Shrink) without changing
    the training code path. All other attribute access is forwarded to the
    delegate via ``__getattr__``.
    """

    def __init__(self, original_unet: UNet2DConditionModel):
        self.delegate = original_unet

        # override original model's forward method: because forward is not called by `__call__`
        # overriding `__call__` is not enough, because nn.Module.forward has a special handling
        self.delegate.forward = self.forward

        # override original model's up blocks' forward method so skip connections
        # of mismatched spatial size (caused by Deep Shrink) can be resized
        for up_block in self.delegate.up_blocks:
            if up_block.__class__.__name__ == "UpBlock2D":

                def resnet_wrapper(func, block):
                    # bind `block` now — a plain closure over the loop variable
                    # would late-bind to the last up_block
                    def forward(*args, **kwargs):
                        return func(block, *args, **kwargs)

                    return forward

                up_block.forward = resnet_wrapper(self.up_block_forward, up_block)

            elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":

                def cross_attn_up_wrapper(func, block):
                    def forward(*args, **kwargs):
                        return func(block, *args, **kwargs)

                    return forward

                up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)

        # Deep Shrink state (disabled until set_deep_shrink() is called)
        self.ds_depth_1 = None
        self.ds_depth_2 = None
        self.ds_timesteps_1 = None
        self.ds_timesteps_2 = None
        self.ds_ratio = None

    # call original model's methods / attributes transparently
    def __getattr__(self, name):
        return getattr(self.delegate, name)

    def __call__(self, *args, **kwargs):
        return self.delegate(*args, **kwargs)

    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
        """Configure Deep Shrink; pass ``ds_depth_1=None`` to disable it.

        Args:
            ds_depth_1: down-block depth at which to shrink (None disables).
            ds_timesteps_1: shrink at depth 1 while timestep >= this value.
            ds_depth_2: optional second depth (defaults to -1, i.e. unused).
            ds_timesteps_2: lower timestep bound for the second depth window
                (defaults to 1000).
            ds_ratio: spatial scale factor applied when shrinking.
        """
        if ds_depth_1 is None:
            print("Deep Shrink is disabled.")
            self.ds_depth_1 = None
            self.ds_timesteps_1 = None
            self.ds_depth_2 = None
            self.ds_timesteps_2 = None
            self.ds_ratio = None
        else:
            print(
                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
            )
            self.ds_depth_1 = ds_depth_1
            self.ds_timesteps_1 = ds_timesteps_1
            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
            self.ds_ratio = ds_ratio

    def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """Replacement for ``UpBlock2D.forward`` (``_self`` is the up block).

        Identical to the original except it resizes ``hidden_states`` to the
        skip connection's spatial size when they differ (Deep Shrink).
        """
        for resnet in _self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # Deep Shrink: skip connection and hidden states may have different
            # spatial sizes after mid-network downscaling
            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
                hidden_states = resize_like(hidden_states, res_hidden_states)

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)

        if _self.upsamplers is not None:
            for upsampler in _self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

    def cross_attn_up_block_forward(
        self,
        _self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
    ):
        """Replacement for ``CrossAttnUpBlock2D.forward`` (``_self`` is the up block).

        Same Deep Shrink resize as :meth:`up_block_forward`, plus the cross
        attention call after each resnet.
        """
        for resnet, attn in zip(_self.resnets, _self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # Deep Shrink: resize to the skip connection's spatial size if needed
            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
                hidden_states = resize_like(hidden_states, res_hidden_states)

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

        if _self.upsamplers is not None:
            for upsampler in _self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
    ) -> Union[Dict, Tuple]:
        r"""
        Current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.

        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a dict instead of a plain tuple.
        Returns:
            `SampleOutput` or `tuple`:
            `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        _self = self.delegate

        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        # (For best quality keep sizes divisible by 64 — forced resizing may degrade output.)
        default_overall_up_factor = 2**_self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            # logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # 1. time
        timesteps = timestep
        timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # only normalizes unusual timestep formats

        t_emb = _self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=_self.dtype)
        emb = _self.time_embedding(t_emb)

        # 2. pre-process
        sample = _self.conv_in(sample)

        down_block_res_samples = (sample,)
        for depth, downsample_block in enumerate(_self.down_blocks):
            # Deep Shrink: downscale the hidden states at the configured
            # depth/timestep window
            if self.ds_depth_1 is not None:
                if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
                    self.ds_depth_2 is not None
                    and depth == self.ds_depth_2
                    and timesteps[0] < self.ds_timesteps_1
                    and timesteps[0] >= self.ds_timesteps_2
                ):
                    org_dtype = sample.dtype
                    # bicubic interpolation needs float32 when the tensor is bf16
                    if org_dtype == torch.bfloat16:
                        sample = sample.to(torch.float32)
                    sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)

            # NOTE: the down blocks could simply always accept encoder_hidden_states
            # in forward, but this explicit branch is arguably easier to follow
            if downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # add the ControlNet outputs to the skip connections
        if down_block_additional_residuals is not None:
            down_block_res_samples = list(down_block_res_samples)
            for i in range(len(down_block_res_samples)):
                down_block_res_samples[i] += down_block_additional_residuals[i]
            down_block_res_samples = tuple(down_block_res_samples)

        # 4. mid
        sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

        # add the ControlNet mid-block output
        if mid_block_additional_residual is not None:
            sample += mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(_self.up_blocks):
            is_final_block = i == len(_self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]  # skip connection

            # if we have not reached the final block and need to forward the upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        sample = _self.conv_norm_out(sample)
        sample = _self.conv_act(sample)
        sample = _self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return SampleOutput(sample=sample)

View File

@@ -24,7 +24,7 @@
import math import math
from types import SimpleNamespace from types import SimpleNamespace
from typing import Optional from typing import Any, Optional
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
@@ -1013,31 +1013,6 @@ class SdxlUNet2DConditionModel(nn.Module):
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)] [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
) )
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
# region diffusers compatibility # region diffusers compatibility
def prepare_config(self): def prepare_config(self):
self.config = SimpleNamespace() self.config = SimpleNamespace()
@@ -1120,7 +1095,97 @@ class SdxlUNet2DConditionModel(nn.Module):
# h = x.type(self.dtype) # h = x.type(self.dtype)
h = x h = x
for depth, module in enumerate(self.input_blocks): for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
return h
class InferSdxlUNet2DConditionModel:
    """Inference-time wrapper around ``SdxlUNet2DConditionModel`` adding Deep Shrink.

    Keeps the original U-Net untouched as ``self.delegate`` and overrides its
    ``forward`` with this class's Deep-Shrink-aware implementation; all other
    attribute access is forwarded to the delegate via ``__getattr__``.
    """

    def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
        self.delegate = original_unet

        # override original model's forward method: because forward is not called by `__call__`
        # overriding `__call__` is not enough, because nn.Module.forward has a special handling
        self.delegate.forward = self.forward

        # Deep Shrink state (disabled until set_deep_shrink() is called)
        self.ds_depth_1 = None
        self.ds_depth_2 = None
        self.ds_timesteps_1 = None
        self.ds_timesteps_2 = None
        self.ds_ratio = None

    # call original model's methods / attributes transparently
    def __getattr__(self, name):
        return getattr(self.delegate, name)

    def __call__(self, *args, **kwargs):
        return self.delegate(*args, **kwargs)

    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
        """Configure Deep Shrink; pass ``ds_depth_1=None`` to disable it.

        Args:
            ds_depth_1: input-block depth at which to shrink (None disables).
            ds_timesteps_1: shrink at depth 1 while timestep >= this value.
            ds_depth_2: optional second depth (defaults to -1, i.e. unused).
            ds_timesteps_2: lower timestep bound for the second depth window
                (defaults to 1000).
            ds_ratio: spatial scale factor applied when shrinking.
        """
        if ds_depth_1 is None:
            print("Deep Shrink is disabled.")
            self.ds_depth_1 = None
            self.ds_timesteps_1 = None
            self.ds_depth_2 = None
            self.ds_timesteps_2 = None
            self.ds_ratio = None
        else:
            print(
                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
            )
            self.ds_depth_1 = ds_depth_1
            self.ds_timesteps_1 = ds_timesteps_1
            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
            self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
"""
_self = self.delegate
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
# assert x.dtype == _self.dtype
emb = emb + _self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
if isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
# h = x.type(self.dtype)
h = x
for depth, module in enumerate(_self.input_blocks):
# Deep Shrink # Deep Shrink
if self.ds_depth_1 is not None: if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
@@ -1138,9 +1203,9 @@ class SdxlUNet2DConditionModel(nn.Module):
h = call_module(module, h, emb, context) h = call_module(module, h, emb, context)
hs.append(h) hs.append(h)
h = call_module(self.middle_block, h, emb, context) h = call_module(_self.middle_block, h, emb, context)
for module in self.output_blocks: for module in _self.output_blocks:
# Deep Shrink # Deep Shrink
if self.ds_depth_1 is not None: if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]: if hs[-1].shape[-2:] != h.shape[-2:]:
@@ -1156,7 +1221,7 @@ class SdxlUNet2DConditionModel(nn.Module):
h = resize_like(h, x) h = resize_like(h, x)
h = h.type(x.dtype) h = h.type(x.dtype)
h = call_module(self.out, h, emb, context) h = call_module(_self.out, h, emb, context)
return h return h

View File

@@ -62,7 +62,7 @@ def cat_h(sliced):
return x return x
def resblock_forward(_self, num_slices, input_tensor, temb): def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
assert _self.upsample is None and _self.downsample is None assert _self.upsample is None and _self.downsample is None
assert _self.norm1.num_groups == _self.norm2.num_groups assert _self.norm1.num_groups == _self.norm2.num_groups
assert temb is None assert temb is None

View File

@@ -13,8 +13,8 @@ from library import sai_model_spec, model_util, sdxl_model_util
import lora import lora
CLAMP_QUANTILE = 0.99 # CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-1 # MIN_DIFF = 1e-1
def save_to_file(file_name, model, state_dict, dtype): def save_to_file(file_name, model, state_dict, dtype):
@@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype):
torch.save(model, file_name) torch.save(model, file_name)
def svd(args): def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
):
def str_to_dtype(p): def str_to_dtype(p):
if p == "float": if p == "float":
return torch.float return torch.float
@@ -39,44 +53,42 @@ def svd(args):
return torch.bfloat16 return torch.bfloat16
return None return None
assert args.v2 != args.sdxl or ( assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
not args.v2 and not args.sdxl if v_parameterization is None:
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" v_parameterization = v2
if args.v_parameterization is None:
args.v_parameterization = args.v2
save_dtype = str_to_dtype(args.save_precision) save_dtype = str_to_dtype(save_precision)
# load models # load models
if not args.sdxl: if not sdxl:
print(f"loading original SD model : {args.model_org}") print(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o] text_encoders_o = [text_encoder_o]
print(f"loading tuned SD model : {args.model_tuned}") print(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t] text_encoders_t = [text_encoder_t]
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else: else:
print(f"loading original SDXL model : {args.model_org}") print(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
) )
text_encoders_o = [text_encoder_o1, text_encoder_o2] text_encoders_o = [text_encoder_o1, text_encoder_o2]
print(f"loading original SDXL model : {args.model_tuned}") print(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
) )
text_encoders_t = [text_encoder_t1, text_encoder_t2] text_encoders_t = [text_encoder_t1, text_encoder_t2]
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# create LoRA network to extract weights: Use dim (rank) as alpha # create LoRA network to extract weights: Use dim (rank) as alpha
if args.conv_dim is None: if conv_dim is None:
kwargs = {} kwargs = {}
else: else:
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len( assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras lora_network_t.text_encoder_loras
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース " ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
@@ -91,9 +103,9 @@ def svd(args):
diff = module_t.weight - module_o.weight diff = module_t.weight - module_o.weight
# Text Encoder might be same # Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF: if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}") print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
diff = diff.float() diff = diff.float()
diffs[lora_name] = diff diffs[lora_name] = diff
@@ -120,16 +132,16 @@ def svd(args):
lora_weights = {} lora_weights = {}
with torch.no_grad(): with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())): for lora_name, mat in tqdm(list(diffs.items())):
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4 conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4] kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1) conv2d_3x3 = conv2d and kernel_size != (1, 1)
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
out_dim, in_dim = mat.size()[0:2] out_dim, in_dim = mat.size()[0:2]
if args.device: if device:
mat = mat.to(args.device) mat = mat.to(device)
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
@@ -149,7 +161,7 @@ def svd(args):
Vh = Vh[:rank, :] Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()]) dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE) hi_val = torch.quantile(dist, clamp_quantile)
low_val = -hi_val low_val = -hi_val
U = U.clamp(low_val, hi_val) U = U.clamp(low_val, hi_val)
@@ -178,34 +190,32 @@ def svd(args):
info = lora_network_save.load_state_dict(lora_sd) info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}") print(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(args.save_to) dir_name = os.path.dirname(save_to)
if dir_name and not os.path.exists(dir_name): if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
# minimum metadata # minimum metadata
net_kwargs = {} net_kwargs = {}
if args.conv_dim is not None: if conv_dim is not None:
net_kwargs["conv_dim"] = args.conv_dim net_kwargs["conv_dim"] = str(conv_dim)
net_kwargs["conv_alpha"] = args.conv_dim net_kwargs["conv_alpha"] = str(float(conv_dim))
metadata = { metadata = {
"ss_v2": str(args.v2), "ss_v2": str(v2),
"ss_base_model_version": model_version, "ss_base_model_version": model_version,
"ss_network_module": "networks.lora", "ss_network_module": "networks.lora",
"ss_network_dim": str(args.dim), "ss_network_dim": str(dim),
"ss_network_alpha": str(args.dim), "ss_network_alpha": str(float(dim)),
"ss_network_args": json.dumps(net_kwargs), "ss_network_args": json.dumps(net_kwargs),
} }
if not args.no_metadata: if not no_metadata:
title = os.path.splitext(os.path.basename(args.save_to))[0] title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata( sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
)
metadata.update(sai_metadata) metadata.update(sai_metadata)
lora_network_save.save_weights(args.save_to, save_dtype, metadata) lora_network_save.save_weights(save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}") print(f"LoRA weights are saved to: {save_to}")
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
@@ -213,7 +223,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
parser.add_argument( parser.add_argument(
"--v_parameterization", "--v_parameterization",
type=bool, action="store_true",
default=None, default=None,
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する省略時はv2と同じ", help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する省略時はv2と同じ",
) )
@@ -231,16 +241,22 @@ def setup_parser() -> argparse.ArgumentParser:
"--model_org", "--model_org",
type=str, type=str,
default=None, default=None,
required=True,
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
) )
parser.add_argument( parser.add_argument(
"--model_tuned", "--model_tuned",
type=str, type=str,
default=None, default=None,
required=True,
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors", help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors",
) )
parser.add_argument( parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" "--save_to",
type=str,
default=None,
required=True,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
) )
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4") parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument( parser.add_argument(
@@ -250,6 +266,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし", help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし",
) )
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--clamp_quantile",
type=float,
default=0.99,
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
)
parser.add_argument(
"--min_diff",
type=float,
default=0.01,
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
)
parser.add_argument( parser.add_argument(
"--no_metadata", "--no_metadata",
action="store_true", action="store_true",
@@ -264,4 +293,4 @@ if __name__ == "__main__":
parser = setup_parser() parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
svd(args) svd(**vars(args))

View File

@@ -57,7 +57,7 @@ import library.train_util as train_util
import library.sdxl_model_util as sdxl_model_util import library.sdxl_model_util as sdxl_model_util
import library.sdxl_train_util as sdxl_train_util import library.sdxl_train_util as sdxl_train_util
from networks.lora import LoRANetwork from networks.lora import LoRANetwork
from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.original_unet import FlashAttentionFunction from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite from networks.control_net_lllite import ControlNetLLLite
@@ -290,7 +290,7 @@ class PipelineLike:
vae: AutoencoderKL, vae: AutoencoderKL,
text_encoders: List[CLIPTextModel], text_encoders: List[CLIPTextModel],
tokenizers: List[CLIPTokenizer], tokenizers: List[CLIPTokenizer],
unet: SdxlUNet2DConditionModel, unet: InferSdxlUNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
clip_skip: int, clip_skip: int,
): ):
@@ -328,7 +328,7 @@ class PipelineLike:
self.vae = vae self.vae = vae
self.text_encoders = text_encoders self.text_encoders = text_encoders
self.tokenizers = tokenizers self.tokenizers = tokenizers
self.unet: SdxlUNet2DConditionModel = unet self.unet: InferSdxlUNet2DConditionModel = unet
self.scheduler = scheduler self.scheduler = scheduler
self.safety_checker = None self.safety_checker = None
@@ -1611,6 +1611,7 @@ def main(args):
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
) )
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
# xformers、Hypernetwork対応 # xformers、Hypernetwork対応
if not args.diffusers_xformers: if not args.diffusers_xformers:
@@ -1766,10 +1767,14 @@ def main(args):
print("set vae_dtype to float32") print("set vae_dtype to float32")
vae_dtype = torch.float32 vae_dtype = torch.float32
vae.to(vae_dtype).to(device) vae.to(vae_dtype).to(device)
vae.eval()
text_encoder1.to(dtype).to(device) text_encoder1.to(dtype).to(device)
text_encoder2.to(dtype).to(device) text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device) unet.to(dtype).to(device)
text_encoder1.eval()
text_encoder2.eval()
unet.eval()
# networkを組み込む # networkを組み込む
if args.network_module: if args.network_module:

View File

@@ -460,7 +460,7 @@ def train(args):
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred: if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: if args.v_pred_like_loss:

View File

@@ -430,7 +430,7 @@ def train(args):
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred: if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: if args.v_pred_like_loss:

View File

@@ -449,7 +449,7 @@ def train(args):
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

View File

@@ -342,7 +342,7 @@ def train(args):
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred: if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss: if args.debiased_estimation_loss:

View File

@@ -812,7 +812,7 @@ class NetworkTrainer:
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred: if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: if args.v_pred_like_loss:

View File

@@ -578,7 +578,7 @@ class TextualInversionTrainer:
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred: if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: if args.v_pred_like_loss:

View File

@@ -469,7 +469,7 @@ def train(args):
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred: if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss: if args.debiased_estimation_loss: