mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into gradual_latent_hires_fix
This commit is contained in:
@@ -355,7 +355,7 @@ def train(args):
|
|||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ import library.train_util as train_util
|
|||||||
from networks.lora import LoRANetwork
|
from networks.lora import LoRANetwork
|
||||||
import tools.original_control_net as original_control_net
|
import tools.original_control_net as original_control_net
|
||||||
from tools.original_control_net import ControlNetInfo
|
from tools.original_control_net import ControlNetInfo
|
||||||
from library.original_unet import UNet2DConditionModel
|
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
|
||||||
from library.original_unet import FlashAttentionFunction
|
from library.original_unet import FlashAttentionFunction
|
||||||
|
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
@@ -378,7 +378,7 @@ class PipelineLike:
|
|||||||
vae: AutoencoderKL,
|
vae: AutoencoderKL,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
unet: UNet2DConditionModel,
|
unet: InferUNet2DConditionModel,
|
||||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||||
clip_skip: int,
|
clip_skip: int,
|
||||||
clip_model: CLIPModel,
|
clip_model: CLIPModel,
|
||||||
@@ -2365,6 +2365,7 @@ def main(args):
|
|||||||
)
|
)
|
||||||
original_unet.load_state_dict(unet.state_dict())
|
original_unet.load_state_dict(unet.state_dict())
|
||||||
unet = original_unet
|
unet = original_unet
|
||||||
|
unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)
|
||||||
|
|
||||||
# VAEを読み込む
|
# VAEを読み込む
|
||||||
if args.vae is not None:
|
if args.vae is not None:
|
||||||
@@ -2521,13 +2522,20 @@ def main(args):
|
|||||||
vae = sli_vae
|
vae = sli_vae
|
||||||
del sli_vae
|
del sli_vae
|
||||||
vae.to(dtype).to(device)
|
vae.to(dtype).to(device)
|
||||||
|
vae.eval()
|
||||||
|
|
||||||
text_encoder.to(dtype).to(device)
|
text_encoder.to(dtype).to(device)
|
||||||
unet.to(dtype).to(device)
|
unet.to(dtype).to(device)
|
||||||
|
|
||||||
|
text_encoder.eval()
|
||||||
|
unet.eval()
|
||||||
|
|
||||||
if clip_model is not None:
|
if clip_model is not None:
|
||||||
clip_model.to(dtype).to(device)
|
clip_model.to(dtype).to(device)
|
||||||
|
clip_model.eval()
|
||||||
if vgg16_model is not None:
|
if vgg16_model is not None:
|
||||||
vgg16_model.to(dtype).to(device)
|
vgg16_model.to(dtype).to(device)
|
||||||
|
vgg16_model.eval()
|
||||||
|
|
||||||
# networkを組み込む
|
# networkを組み込む
|
||||||
if args.network_module:
|
if args.network_module:
|
||||||
|
|||||||
@@ -57,10 +57,13 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
|||||||
noise_scheduler.alphas_cumprod = alphas_cumprod
|
noise_scheduler.alphas_cumprod = alphas_cumprod
|
||||||
|
|
||||||
|
|
||||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
||||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
|
if v_prediction:
|
||||||
|
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
|
||||||
|
else:
|
||||||
|
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
||||||
loss = loss * snr_weight
|
loss = loss * snr_weight
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|||||||
@@ -1148,10 +1148,6 @@ class UpBlock2D(nn.Module):
|
|||||||
res_hidden_states = res_hidden_states_tuple[-1]
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
# Deep Shrink
|
|
||||||
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
|
||||||
hidden_states = resize_like(hidden_states, res_hidden_states)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
@@ -1244,10 +1240,6 @@ class CrossAttnUpBlock2D(nn.Module):
|
|||||||
res_hidden_states = res_hidden_states_tuple[-1]
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
# Deep Shrink
|
|
||||||
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
|
||||||
hidden_states = resize_like(hidden_states, res_hidden_states)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
@@ -1444,31 +1436,6 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
self.conv_act = nn.SiLU()
|
self.conv_act = nn.SiLU()
|
||||||
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
||||||
|
|
||||||
# Deep Shrink
|
|
||||||
self.ds_depth_1 = None
|
|
||||||
self.ds_depth_2 = None
|
|
||||||
self.ds_timesteps_1 = None
|
|
||||||
self.ds_timesteps_2 = None
|
|
||||||
self.ds_ratio = None
|
|
||||||
|
|
||||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
|
||||||
if ds_depth_1 is None:
|
|
||||||
print("Deep Shrink is disabled.")
|
|
||||||
self.ds_depth_1 = None
|
|
||||||
self.ds_timesteps_1 = None
|
|
||||||
self.ds_depth_2 = None
|
|
||||||
self.ds_timesteps_2 = None
|
|
||||||
self.ds_ratio = None
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
|
||||||
)
|
|
||||||
self.ds_depth_1 = ds_depth_1
|
|
||||||
self.ds_timesteps_1 = ds_timesteps_1
|
|
||||||
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
|
||||||
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
|
||||||
self.ds_ratio = ds_ratio
|
|
||||||
|
|
||||||
# region diffusers compatibility
|
# region diffusers compatibility
|
||||||
def prepare_config(self):
|
def prepare_config(self):
|
||||||
self.config = SimpleNamespace()
|
self.config = SimpleNamespace()
|
||||||
@@ -1572,20 +1539,7 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
sample = self.conv_in(sample)
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
down_block_res_samples = (sample,)
|
down_block_res_samples = (sample,)
|
||||||
for depth, downsample_block in enumerate(self.down_blocks):
|
for downsample_block in self.down_blocks:
|
||||||
# Deep Shrink
|
|
||||||
if self.ds_depth_1 is not None:
|
|
||||||
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
|
||||||
self.ds_depth_2 is not None
|
|
||||||
and depth == self.ds_depth_2
|
|
||||||
and timesteps[0] < self.ds_timesteps_1
|
|
||||||
and timesteps[0] >= self.ds_timesteps_2
|
|
||||||
):
|
|
||||||
org_dtype = sample.dtype
|
|
||||||
if org_dtype == torch.bfloat16:
|
|
||||||
sample = sample.to(torch.float32)
|
|
||||||
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
|
||||||
|
|
||||||
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||||
# まあこちらのほうがわかりやすいかもしれない
|
# まあこちらのほうがわかりやすいかもしれない
|
||||||
if downsample_block.has_cross_attention:
|
if downsample_block.has_cross_attention:
|
||||||
@@ -1668,3 +1622,255 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
timesteps = timesteps.expand(sample.shape[0])
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
|
||||||
return timesteps
|
return timesteps
|
||||||
|
|
||||||
|
|
||||||
|
class InferUNet2DConditionModel:
|
||||||
|
def __init__(self, original_unet: UNet2DConditionModel):
|
||||||
|
self.delegate = original_unet
|
||||||
|
|
||||||
|
# override original model's forward method: because forward is not called by `__call__`
|
||||||
|
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
||||||
|
self.delegate.forward = self.forward
|
||||||
|
|
||||||
|
# override original model's up blocks' forward method
|
||||||
|
for up_block in self.delegate.up_blocks:
|
||||||
|
if up_block.__class__.__name__ == "UpBlock2D":
|
||||||
|
|
||||||
|
def resnet_wrapper(func, block):
|
||||||
|
def forward(*args, **kwargs):
|
||||||
|
return func(block, *args, **kwargs)
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
||||||
|
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
|
||||||
|
|
||||||
|
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||||
|
|
||||||
|
def cross_attn_up_wrapper(func, block):
|
||||||
|
def forward(*args, **kwargs):
|
||||||
|
return func(block, *args, **kwargs)
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
||||||
|
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
|
||||||
|
# call original model's methods
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self.delegate, name)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.delegate(*args, **kwargs)
|
||||||
|
|
||||||
|
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||||
|
if ds_depth_1 is None:
|
||||||
|
print("Deep Shrink is disabled.")
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||||
|
)
|
||||||
|
self.ds_depth_1 = ds_depth_1
|
||||||
|
self.ds_timesteps_1 = ds_timesteps_1
|
||||||
|
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
||||||
|
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
||||||
|
self.ds_ratio = ds_ratio
|
||||||
|
|
||||||
|
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||||
|
for resnet in _self.resnets:
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
||||||
|
hidden_states = resize_like(hidden_states, res_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
|
||||||
|
if _self.upsamplers is not None:
|
||||||
|
for upsampler in _self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def cross_attn_up_block_forward(
|
||||||
|
self,
|
||||||
|
_self,
|
||||||
|
hidden_states,
|
||||||
|
res_hidden_states_tuple,
|
||||||
|
temb=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
upsample_size=None,
|
||||||
|
):
|
||||||
|
for resnet, attn in zip(_self.resnets, _self.attentions):
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
||||||
|
hidden_states = resize_like(hidden_states, res_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
hidden_states = resnet(hidden_states, temb)
|
||||||
|
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||||
|
|
||||||
|
if _self.upsamplers is not None:
|
||||||
|
for upsampler in _self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[Dict, Tuple]:
|
||||||
|
r"""
|
||||||
|
current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
|
||||||
|
"""
|
||||||
|
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||||
|
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||||
|
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a dict instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`SampleOutput` or `tuple`:
|
||||||
|
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_self = self.delegate
|
||||||
|
|
||||||
|
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||||
|
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
||||||
|
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||||
|
# on the fly if necessary.
|
||||||
|
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
||||||
|
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
||||||
|
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
||||||
|
default_overall_up_factor = 2**_self.num_upsamplers
|
||||||
|
|
||||||
|
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||||
|
# 64で割り切れないときはupsamplerにサイズを伝える
|
||||||
|
forward_upsample_size = False
|
||||||
|
upsample_size = None
|
||||||
|
|
||||||
|
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||||
|
# logger.info("Forward upsample size to force interpolation output size.")
|
||||||
|
forward_upsample_size = True
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
timesteps = timestep
|
||||||
|
timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
||||||
|
|
||||||
|
t_emb = _self.time_proj(timesteps)
|
||||||
|
|
||||||
|
# timesteps does not contain any weights and will always return f32 tensors
|
||||||
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
|
# there might be better ways to encapsulate this.
|
||||||
|
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
||||||
|
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
||||||
|
# time_projでキャストしておけばいいんじゃね?
|
||||||
|
t_emb = t_emb.to(dtype=_self.dtype)
|
||||||
|
emb = _self.time_embedding(t_emb)
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = _self.conv_in(sample)
|
||||||
|
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for depth, downsample_block in enumerate(_self.down_blocks):
|
||||||
|
# Deep Shrink
|
||||||
|
if self.ds_depth_1 is not None:
|
||||||
|
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||||
|
self.ds_depth_2 is not None
|
||||||
|
and depth == self.ds_depth_2
|
||||||
|
and timesteps[0] < self.ds_timesteps_1
|
||||||
|
and timesteps[0] >= self.ds_timesteps_2
|
||||||
|
):
|
||||||
|
org_dtype = sample.dtype
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
sample = sample.to(torch.float32)
|
||||||
|
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
||||||
|
|
||||||
|
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||||
|
# まあこちらのほうがわかりやすいかもしれない
|
||||||
|
if downsample_block.has_cross_attention:
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# skip connectionにControlNetの出力を追加する
|
||||||
|
if down_block_additional_residuals is not None:
|
||||||
|
down_block_res_samples = list(down_block_res_samples)
|
||||||
|
for i in range(len(down_block_res_samples)):
|
||||||
|
down_block_res_samples[i] += down_block_additional_residuals[i]
|
||||||
|
down_block_res_samples = tuple(down_block_res_samples)
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||||
|
|
||||||
|
# ControlNetの出力を追加する
|
||||||
|
if mid_block_additional_residual is not None:
|
||||||
|
sample += mid_block_additional_residual
|
||||||
|
|
||||||
|
# 5. up
|
||||||
|
for i, upsample_block in enumerate(_self.up_blocks):
|
||||||
|
is_final_block = i == len(_self.up_blocks) - 1
|
||||||
|
|
||||||
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||||
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
||||||
|
|
||||||
|
# if we have not reached the final block and need to forward the upsample size, we do it here
|
||||||
|
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
||||||
|
if not is_final_block and forward_upsample_size:
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
|
if upsample_block.has_cross_attention:
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. post-process
|
||||||
|
sample = _self.conv_norm_out(sample)
|
||||||
|
sample = _self.conv_act(sample)
|
||||||
|
sample = _self.conv_out(sample)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sample,)
|
||||||
|
|
||||||
|
return SampleOutput(sample=sample)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -1013,31 +1013,6 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Deep Shrink
|
|
||||||
self.ds_depth_1 = None
|
|
||||||
self.ds_depth_2 = None
|
|
||||||
self.ds_timesteps_1 = None
|
|
||||||
self.ds_timesteps_2 = None
|
|
||||||
self.ds_ratio = None
|
|
||||||
|
|
||||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
|
||||||
if ds_depth_1 is None:
|
|
||||||
print("Deep Shrink is disabled.")
|
|
||||||
self.ds_depth_1 = None
|
|
||||||
self.ds_timesteps_1 = None
|
|
||||||
self.ds_depth_2 = None
|
|
||||||
self.ds_timesteps_2 = None
|
|
||||||
self.ds_ratio = None
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
|
||||||
)
|
|
||||||
self.ds_depth_1 = ds_depth_1
|
|
||||||
self.ds_timesteps_1 = ds_timesteps_1
|
|
||||||
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
|
||||||
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
|
||||||
self.ds_ratio = ds_ratio
|
|
||||||
|
|
||||||
# region diffusers compatibility
|
# region diffusers compatibility
|
||||||
def prepare_config(self):
|
def prepare_config(self):
|
||||||
self.config = SimpleNamespace()
|
self.config = SimpleNamespace()
|
||||||
@@ -1120,7 +1095,97 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
# h = x.type(self.dtype)
|
# h = x.type(self.dtype)
|
||||||
h = x
|
h = x
|
||||||
|
|
||||||
for depth, module in enumerate(self.input_blocks):
|
for module in self.input_blocks:
|
||||||
|
h = call_module(module, h, emb, context)
|
||||||
|
hs.append(h)
|
||||||
|
|
||||||
|
h = call_module(self.middle_block, h, emb, context)
|
||||||
|
|
||||||
|
for module in self.output_blocks:
|
||||||
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
|
h = call_module(module, h, emb, context)
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
h = call_module(self.out, h, emb, context)
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class InferSdxlUNet2DConditionModel:
|
||||||
|
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
|
||||||
|
self.delegate = original_unet
|
||||||
|
|
||||||
|
# override original model's forward method: because forward is not called by `__call__`
|
||||||
|
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
||||||
|
self.delegate.forward = self.forward
|
||||||
|
|
||||||
|
# Deep Shrink
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
|
||||||
|
# call original model's methods
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self.delegate, name)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.delegate(*args, **kwargs)
|
||||||
|
|
||||||
|
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||||
|
if ds_depth_1 is None:
|
||||||
|
print("Deep Shrink is disabled.")
|
||||||
|
self.ds_depth_1 = None
|
||||||
|
self.ds_timesteps_1 = None
|
||||||
|
self.ds_depth_2 = None
|
||||||
|
self.ds_timesteps_2 = None
|
||||||
|
self.ds_ratio = None
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||||
|
)
|
||||||
|
self.ds_depth_1 = ds_depth_1
|
||||||
|
self.ds_timesteps_1 = ds_timesteps_1
|
||||||
|
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
||||||
|
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
||||||
|
self.ds_ratio = ds_ratio
|
||||||
|
|
||||||
|
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||||
|
r"""
|
||||||
|
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
|
||||||
|
"""
|
||||||
|
_self = self.delegate
|
||||||
|
|
||||||
|
# broadcast timesteps to batch dimension
|
||||||
|
timesteps = timesteps.expand(x.shape[0])
|
||||||
|
|
||||||
|
hs = []
|
||||||
|
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
|
||||||
|
t_emb = t_emb.to(x.dtype)
|
||||||
|
emb = _self.time_embed(t_emb)
|
||||||
|
|
||||||
|
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
||||||
|
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
||||||
|
# assert x.dtype == _self.dtype
|
||||||
|
emb = emb + _self.label_emb(y)
|
||||||
|
|
||||||
|
def call_module(module, h, emb, context):
|
||||||
|
x = h
|
||||||
|
for layer in module:
|
||||||
|
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
||||||
|
if isinstance(layer, ResnetBlock2D):
|
||||||
|
x = layer(x, emb)
|
||||||
|
elif isinstance(layer, Transformer2DModel):
|
||||||
|
x = layer(x, context)
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
# h = x.type(self.dtype)
|
||||||
|
h = x
|
||||||
|
|
||||||
|
for depth, module in enumerate(_self.input_blocks):
|
||||||
# Deep Shrink
|
# Deep Shrink
|
||||||
if self.ds_depth_1 is not None:
|
if self.ds_depth_1 is not None:
|
||||||
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||||
@@ -1138,9 +1203,9 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
h = call_module(module, h, emb, context)
|
h = call_module(module, h, emb, context)
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
|
|
||||||
h = call_module(self.middle_block, h, emb, context)
|
h = call_module(_self.middle_block, h, emb, context)
|
||||||
|
|
||||||
for module in self.output_blocks:
|
for module in _self.output_blocks:
|
||||||
# Deep Shrink
|
# Deep Shrink
|
||||||
if self.ds_depth_1 is not None:
|
if self.ds_depth_1 is not None:
|
||||||
if hs[-1].shape[-2:] != h.shape[-2:]:
|
if hs[-1].shape[-2:] != h.shape[-2:]:
|
||||||
@@ -1156,7 +1221,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
h = resize_like(h, x)
|
h = resize_like(h, x)
|
||||||
|
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
h = call_module(self.out, h, emb, context)
|
h = call_module(_self.out, h, emb, context)
|
||||||
|
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ def cat_h(sliced):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def resblock_forward(_self, num_slices, input_tensor, temb):
|
def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
|
||||||
assert _self.upsample is None and _self.downsample is None
|
assert _self.upsample is None and _self.downsample is None
|
||||||
assert _self.norm1.num_groups == _self.norm2.num_groups
|
assert _self.norm1.num_groups == _self.norm2.num_groups
|
||||||
assert temb is None
|
assert temb is None
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ from library import sai_model_spec, model_util, sdxl_model_util
|
|||||||
import lora
|
import lora
|
||||||
|
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
# CLAMP_QUANTILE = 0.99
|
||||||
MIN_DIFF = 1e-1
|
# MIN_DIFF = 1e-1
|
||||||
|
|
||||||
|
|
||||||
def save_to_file(file_name, model, state_dict, dtype):
|
def save_to_file(file_name, model, state_dict, dtype):
|
||||||
@@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype):
|
|||||||
torch.save(model, file_name)
|
torch.save(model, file_name)
|
||||||
|
|
||||||
|
|
||||||
def svd(args):
|
def svd(
|
||||||
|
model_org=None,
|
||||||
|
model_tuned=None,
|
||||||
|
save_to=None,
|
||||||
|
dim=4,
|
||||||
|
v2=None,
|
||||||
|
sdxl=None,
|
||||||
|
conv_dim=None,
|
||||||
|
v_parameterization=None,
|
||||||
|
device=None,
|
||||||
|
save_precision=None,
|
||||||
|
clamp_quantile=0.99,
|
||||||
|
min_diff=0.01,
|
||||||
|
no_metadata=False,
|
||||||
|
):
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == "float":
|
if p == "float":
|
||||||
return torch.float
|
return torch.float
|
||||||
@@ -39,44 +53,42 @@ def svd(args):
|
|||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
return None
|
return None
|
||||||
|
|
||||||
assert args.v2 != args.sdxl or (
|
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
||||||
not args.v2 and not args.sdxl
|
if v_parameterization is None:
|
||||||
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
v_parameterization = v2
|
||||||
if args.v_parameterization is None:
|
|
||||||
args.v_parameterization = args.v2
|
|
||||||
|
|
||||||
save_dtype = str_to_dtype(args.save_precision)
|
save_dtype = str_to_dtype(save_precision)
|
||||||
|
|
||||||
# load models
|
# load models
|
||||||
if not args.sdxl:
|
if not sdxl:
|
||||||
print(f"loading original SD model : {args.model_org}")
|
print(f"loading original SD model : {model_org}")
|
||||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||||
text_encoders_o = [text_encoder_o]
|
text_encoders_o = [text_encoder_o]
|
||||||
print(f"loading tuned SD model : {args.model_tuned}")
|
print(f"loading tuned SD model : {model_tuned}")
|
||||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||||
text_encoders_t = [text_encoder_t]
|
text_encoders_t = [text_encoder_t]
|
||||||
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
|
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||||
else:
|
else:
|
||||||
print(f"loading original SDXL model : {args.model_org}")
|
print(f"loading original SDXL model : {model_org}")
|
||||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
|
||||||
)
|
)
|
||||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||||
print(f"loading original SDXL model : {args.model_tuned}")
|
print(f"loading original SDXL model : {model_tuned}")
|
||||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
|
||||||
)
|
)
|
||||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||||
|
|
||||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||||
if args.conv_dim is None:
|
if conv_dim is None:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
else:
|
else:
|
||||||
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
|
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
|
||||||
|
|
||||||
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
|
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
|
||||||
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
|
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
|
||||||
assert len(lora_network_o.text_encoder_loras) == len(
|
assert len(lora_network_o.text_encoder_loras) == len(
|
||||||
lora_network_t.text_encoder_loras
|
lora_network_t.text_encoder_loras
|
||||||
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||||
@@ -91,9 +103,9 @@ def svd(args):
|
|||||||
diff = module_t.weight - module_o.weight
|
diff = module_t.weight - module_o.weight
|
||||||
|
|
||||||
# Text Encoder might be same
|
# Text Encoder might be same
|
||||||
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
|
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||||
text_encoder_different = True
|
text_encoder_different = True
|
||||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
|
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||||
|
|
||||||
diff = diff.float()
|
diff = diff.float()
|
||||||
diffs[lora_name] = diff
|
diffs[lora_name] = diff
|
||||||
@@ -120,16 +132,16 @@ def svd(args):
|
|||||||
lora_weights = {}
|
lora_weights = {}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora_name, mat in tqdm(list(diffs.items())):
|
for lora_name, mat in tqdm(list(diffs.items())):
|
||||||
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||||
conv2d = len(mat.size()) == 4
|
conv2d = len(mat.size()) == 4
|
||||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||||
|
|
||||||
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
|
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
||||||
out_dim, in_dim = mat.size()[0:2]
|
out_dim, in_dim = mat.size()[0:2]
|
||||||
|
|
||||||
if args.device:
|
if device:
|
||||||
mat = mat.to(args.device)
|
mat = mat.to(device)
|
||||||
|
|
||||||
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||||
@@ -149,7 +161,7 @@ def svd(args):
|
|||||||
Vh = Vh[:rank, :]
|
Vh = Vh[:rank, :]
|
||||||
|
|
||||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
hi_val = torch.quantile(dist, clamp_quantile)
|
||||||
low_val = -hi_val
|
low_val = -hi_val
|
||||||
|
|
||||||
U = U.clamp(low_val, hi_val)
|
U = U.clamp(low_val, hi_val)
|
||||||
@@ -178,34 +190,32 @@ def svd(args):
|
|||||||
info = lora_network_save.load_state_dict(lora_sd)
|
info = lora_network_save.load_state_dict(lora_sd)
|
||||||
print(f"Loading extracted LoRA weights: {info}")
|
print(f"Loading extracted LoRA weights: {info}")
|
||||||
|
|
||||||
dir_name = os.path.dirname(args.save_to)
|
dir_name = os.path.dirname(save_to)
|
||||||
if dir_name and not os.path.exists(dir_name):
|
if dir_name and not os.path.exists(dir_name):
|
||||||
os.makedirs(dir_name, exist_ok=True)
|
os.makedirs(dir_name, exist_ok=True)
|
||||||
|
|
||||||
# minimum metadata
|
# minimum metadata
|
||||||
net_kwargs = {}
|
net_kwargs = {}
|
||||||
if args.conv_dim is not None:
|
if conv_dim is not None:
|
||||||
net_kwargs["conv_dim"] = args.conv_dim
|
net_kwargs["conv_dim"] = str(conv_dim)
|
||||||
net_kwargs["conv_alpha"] = args.conv_dim
|
net_kwargs["conv_alpha"] = str(float(conv_dim))
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"ss_v2": str(args.v2),
|
"ss_v2": str(v2),
|
||||||
"ss_base_model_version": model_version,
|
"ss_base_model_version": model_version,
|
||||||
"ss_network_module": "networks.lora",
|
"ss_network_module": "networks.lora",
|
||||||
"ss_network_dim": str(args.dim),
|
"ss_network_dim": str(dim),
|
||||||
"ss_network_alpha": str(args.dim),
|
"ss_network_alpha": str(float(dim)),
|
||||||
"ss_network_args": json.dumps(net_kwargs),
|
"ss_network_args": json.dumps(net_kwargs),
|
||||||
}
|
}
|
||||||
|
|
||||||
if not args.no_metadata:
|
if not no_metadata:
|
||||||
title = os.path.splitext(os.path.basename(args.save_to))[0]
|
title = os.path.splitext(os.path.basename(save_to))[0]
|
||||||
sai_metadata = sai_model_spec.build_metadata(
|
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
||||||
None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
|
|
||||||
)
|
|
||||||
metadata.update(sai_metadata)
|
metadata.update(sai_metadata)
|
||||||
|
|
||||||
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
|
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
||||||
print(f"LoRA weights are saved to: {args.save_to}")
|
print(f"LoRA weights are saved to: {save_to}")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
@@ -213,7 +223,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--v_parameterization",
|
"--v_parameterization",
|
||||||
type=bool,
|
action="store_true",
|
||||||
default=None,
|
default=None,
|
||||||
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
||||||
)
|
)
|
||||||
@@ -231,16 +241,22 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--model_org",
|
"--model_org",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
|
required=True,
|
||||||
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_tuned",
|
"--model_tuned",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
|
required=True,
|
||||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
"--save_to",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||||
)
|
)
|
||||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -250,6 +266,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||||
|
parser.add_argument(
|
||||||
|
"--clamp_quantile",
|
||||||
|
type=float,
|
||||||
|
default=0.99,
|
||||||
|
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min_diff",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
|
||||||
|
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no_metadata",
|
"--no_metadata",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -264,4 +293,4 @@ if __name__ == "__main__":
|
|||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
svd(args)
|
svd(**vars(args))
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ import library.train_util as train_util
|
|||||||
import library.sdxl_model_util as sdxl_model_util
|
import library.sdxl_model_util as sdxl_model_util
|
||||||
import library.sdxl_train_util as sdxl_train_util
|
import library.sdxl_train_util as sdxl_train_util
|
||||||
from networks.lora import LoRANetwork
|
from networks.lora import LoRANetwork
|
||||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||||
from library.original_unet import FlashAttentionFunction
|
from library.original_unet import FlashAttentionFunction
|
||||||
from networks.control_net_lllite import ControlNetLLLite
|
from networks.control_net_lllite import ControlNetLLLite
|
||||||
|
|
||||||
@@ -290,7 +290,7 @@ class PipelineLike:
|
|||||||
vae: AutoencoderKL,
|
vae: AutoencoderKL,
|
||||||
text_encoders: List[CLIPTextModel],
|
text_encoders: List[CLIPTextModel],
|
||||||
tokenizers: List[CLIPTokenizer],
|
tokenizers: List[CLIPTokenizer],
|
||||||
unet: SdxlUNet2DConditionModel,
|
unet: InferSdxlUNet2DConditionModel,
|
||||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||||
clip_skip: int,
|
clip_skip: int,
|
||||||
):
|
):
|
||||||
@@ -328,7 +328,7 @@ class PipelineLike:
|
|||||||
self.vae = vae
|
self.vae = vae
|
||||||
self.text_encoders = text_encoders
|
self.text_encoders = text_encoders
|
||||||
self.tokenizers = tokenizers
|
self.tokenizers = tokenizers
|
||||||
self.unet: SdxlUNet2DConditionModel = unet
|
self.unet: InferSdxlUNet2DConditionModel = unet
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.safety_checker = None
|
self.safety_checker = None
|
||||||
|
|
||||||
@@ -1611,6 +1611,7 @@ def main(args):
|
|||||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
||||||
)
|
)
|
||||||
|
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
|
||||||
|
|
||||||
# xformers、Hypernetwork対応
|
# xformers、Hypernetwork対応
|
||||||
if not args.diffusers_xformers:
|
if not args.diffusers_xformers:
|
||||||
@@ -1766,10 +1767,14 @@ def main(args):
|
|||||||
print("set vae_dtype to float32")
|
print("set vae_dtype to float32")
|
||||||
vae_dtype = torch.float32
|
vae_dtype = torch.float32
|
||||||
vae.to(vae_dtype).to(device)
|
vae.to(vae_dtype).to(device)
|
||||||
|
vae.eval()
|
||||||
|
|
||||||
text_encoder1.to(dtype).to(device)
|
text_encoder1.to(dtype).to(device)
|
||||||
text_encoder2.to(dtype).to(device)
|
text_encoder2.to(dtype).to(device)
|
||||||
unet.to(dtype).to(device)
|
unet.to(dtype).to(device)
|
||||||
|
text_encoder1.eval()
|
||||||
|
text_encoder2.eval()
|
||||||
|
unet.eval()
|
||||||
|
|
||||||
# networkを組み込む
|
# networkを組み込む
|
||||||
if args.network_module:
|
if args.network_module:
|
||||||
|
|||||||
@@ -460,7 +460,7 @@ def train(args):
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
|
|||||||
@@ -430,7 +430,7 @@ def train(args):
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
|
|||||||
@@ -449,7 +449,7 @@ def train(args):
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -342,7 +342,7 @@ def train(args):
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
|
|||||||
@@ -812,7 +812,7 @@ class NetworkTrainer:
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
|
|||||||
@@ -578,7 +578,7 @@ class TextualInversionTrainer:
|
|||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
|
|||||||
@@ -469,7 +469,7 @@ def train(args):
|
|||||||
|
|
||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.debiased_estimation_loss:
|
if args.debiased_estimation_loss:
|
||||||
|
|||||||
Reference in New Issue
Block a user