diff --git a/fine_tune.py b/fine_tune.py
index 52e84c43..b0787677 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -355,7 +355,7 @@ def train(args):
                     loss = loss.mean([1, 2, 3])
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                 if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index c656e6c6..fb7866fc 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -105,7 +105,7 @@ import library.train_util as train_util
 from networks.lora import LoRANetwork
 import tools.original_control_net as original_control_net
 from tools.original_control_net import ControlNetInfo
-from library.original_unet import UNet2DConditionModel
+from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
 from library.original_unet import FlashAttentionFunction
 
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -378,7 +378,7 @@ class PipelineLike:
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: InferUNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         clip_skip: int,
         clip_model: CLIPModel,
@@ -2365,6 +2365,7 @@ def main(args):
         )
         original_unet.load_state_dict(unet.state_dict())
         unet = original_unet
+    unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)
 
     # VAEを読み込む
     if args.vae is not None:
@@ -2521,13 +2522,20 @@ def main(args):
             vae = sli_vae
             del sli_vae
     vae.to(dtype).to(device)
+    vae.eval()
 
     text_encoder.to(dtype).to(device)
     unet.to(dtype).to(device)
+
+    text_encoder.eval()
+    unet.eval()
+
     if clip_model is not None:
         clip_model.to(dtype).to(device)
+        clip_model.eval()
     if vgg16_model is not None:
         vgg16_model.to(dtype).to(device)
+        vgg16_model.eval()
 
     # networkを組み込む
     if args.network_module:
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 28b625d3..e0a026da 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -57,10 +57,13 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
     noise_scheduler.alphas_cumprod = alphas_cumprod
 
 
-def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
+def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
     snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
-    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
-    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)  # from paper
+    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
+    if v_prediction:
+        snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
+    else:
+        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
     loss = loss * snr_weight
     return loss
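For reference (not part of the patch), here is what the reworked `apply_snr_weight` computes. For epsilon prediction the new form min(SNR, gamma) / SNR is algebraically identical to the previous min(gamma / SNR, 1); the v-prediction branch instead divides by SNR + 1, which matches the commonly used min-SNR weighting for v-prediction targets. A standalone sketch with illustrative SNR values:

    import torch

    snr = torch.tensor([0.1, 1.0, 5.0, 50.0])  # per-timestep signal-to-noise ratios (example values)
    gamma = 5.0

    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    eps_weight = min_snr_gamma / snr      # epsilon prediction: equals the old min(gamma / snr, 1)
    v_weight = min_snr_gamma / (snr + 1)  # v-prediction: divisor becomes SNR + 1

    assert torch.allclose(eps_weight, torch.minimum(gamma / snr, torch.ones_like(snr)))
    print(eps_weight)  # tensor([1.0000, 1.0000, 1.0000, 0.1000])
    print(v_weight)    # tensor([0.0909, 0.5000, 0.8333, 0.0980])

So the epsilon path is a pure refactor, while v-prediction training now gets a strictly smaller weight per timestep.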
diff --git a/library/original_unet.py b/library/original_unet.py
index 0454f13f..938b0b64 100644
--- a/library/original_unet.py
+++ b/library/original_unet.py
@@ -1148,10 +1148,6 @@ class UpBlock2D(nn.Module):
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
 
-            # Deep Shrink
-            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
-                hidden_states = resize_like(hidden_states, res_hidden_states)
-
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if self.training and self.gradient_checkpointing:
@@ -1244,10 +1240,6 @@ class CrossAttnUpBlock2D(nn.Module):
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
 
-            # Deep Shrink
-            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
-                hidden_states = resize_like(hidden_states, res_hidden_states)
-
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if self.training and self.gradient_checkpointing:
@@ -1444,31 +1436,6 @@ class UNet2DConditionModel(nn.Module):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
 
-        # Deep Shrink
-        self.ds_depth_1 = None
-        self.ds_depth_2 = None
-        self.ds_timesteps_1 = None
-        self.ds_timesteps_2 = None
-        self.ds_ratio = None
-
-    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
-        if ds_depth_1 is None:
-            print("Deep Shrink is disabled.")
-            self.ds_depth_1 = None
-            self.ds_timesteps_1 = None
-            self.ds_depth_2 = None
-            self.ds_timesteps_2 = None
-            self.ds_ratio = None
-        else:
-            print(
-                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
-            )
-            self.ds_depth_1 = ds_depth_1
-            self.ds_timesteps_1 = ds_timesteps_1
-            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
-            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
-            self.ds_ratio = ds_ratio
-
     # region diffusers compatibility
     def prepare_config(self):
         self.config = SimpleNamespace()
@@ -1572,20 +1539,7 @@ class UNet2DConditionModel(nn.Module):
         sample = self.conv_in(sample)
 
         down_block_res_samples = (sample,)
-        for depth, downsample_block in enumerate(self.down_blocks):
-            # Deep Shrink
-            if self.ds_depth_1 is not None:
-                if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
-                    self.ds_depth_2 is not None
-                    and depth == self.ds_depth_2
-                    and timesteps[0] < self.ds_timesteps_1
-                    and timesteps[0] >= self.ds_timesteps_2
-                ):
-                    org_dtype = sample.dtype
-                    if org_dtype == torch.bfloat16:
-                        sample = sample.to(torch.float32)
-                    sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
-
+        for downsample_block in self.down_blocks:
             # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
             # まあこちらのほうがわかりやすいかもしれない
             if downsample_block.has_cross_attention:
@@ -1668,3 +1622,255 @@ class UNet2DConditionModel(nn.Module):
             timesteps = timesteps.expand(sample.shape[0])
 
         return timesteps
+
+
+class InferUNet2DConditionModel:
+    def __init__(self, original_unet: UNet2DConditionModel):
+        self.delegate = original_unet
+
+        # override original model's forward method: because forward is not called by `__call__`
+        # overriding `__call__` is not enough, because nn.Module.forward has a special handling
+        self.delegate.forward = self.forward
+
+        # override original model's up blocks' forward method
+        for up_block in self.delegate.up_blocks:
+            if up_block.__class__.__name__ == "UpBlock2D":
+
+                def resnet_wrapper(func, block):
+                    def forward(*args, **kwargs):
+                        return func(block, *args, **kwargs)
+
+                    return forward
+
+                up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
+
+            elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
+
+                def cross_attn_up_wrapper(func, block):
+                    def forward(*args, **kwargs):
+                        return func(block, *args, **kwargs)
+
+                    return forward
+
+                up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
+
+        # Deep Shrink
+        self.ds_depth_1 = None
+        self.ds_depth_2 = None
+        self.ds_timesteps_1 = None
+        self.ds_timesteps_2 = None
+        self.ds_ratio = None
+
+    # call original model's methods
+    def __getattr__(self, name):
+        return getattr(self.delegate, name)
+
+    def __call__(self, *args, **kwargs):
+        return self.delegate(*args, **kwargs)
+
+    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
+        if ds_depth_1 is None:
+            print("Deep Shrink is disabled.")
+            self.ds_depth_1 = None
+            self.ds_timesteps_1 = None
+            self.ds_depth_2 = None
+            self.ds_timesteps_2 = None
+            self.ds_ratio = None
+        else:
+            print(
+                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
+            )
+            self.ds_depth_1 = ds_depth_1
+            self.ds_timesteps_1 = ds_timesteps_1
+            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
+            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
+            self.ds_ratio = ds_ratio
+
+    def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+        for resnet in _self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # Deep Shrink
+            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
+                hidden_states = resize_like(hidden_states, res_hidden_states)
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+            hidden_states = resnet(hidden_states, temb)
+
+        if _self.upsamplers is not None:
+            for upsampler in _self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+    def cross_attn_up_block_forward(
+        self,
+        _self,
+        hidden_states,
+        res_hidden_states_tuple,
+        temb=None,
+        encoder_hidden_states=None,
+        upsample_size=None,
+    ):
+        for resnet, attn in zip(_self.resnets, _self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # Deep Shrink
+            if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
+                hidden_states = resize_like(hidden_states, res_hidden_states)
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+        if _self.upsamplers is not None:
+            for upsampler in _self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+    ) -> Union[Dict, Tuple]:
+        r"""
+        current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
+        """
+
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a dict instead of a plain tuple.
+
+        Returns:
+            `SampleOutput` or `tuple`:
+            `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+        """
+
+        _self = self.delegate
+
+        # By default samples have to be AT least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
+        # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
+        # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
+        default_overall_up_factor = 2**_self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        # 64で割り切れないときはupsamplerにサイズを伝える
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            # logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # 1. time
+        timesteps = timestep
+        timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # 変な時だけ処理
+
+        t_emb = _self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        # timestepsは重みを含まないので常にfloat32のテンソルを返す
+        # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
+        # time_projでキャストしておけばいいんじゃね?
+        t_emb = t_emb.to(dtype=_self.dtype)
+        emb = _self.time_embedding(t_emb)
+
+        # 2. pre-process
+        sample = _self.conv_in(sample)
+
+        down_block_res_samples = (sample,)
+        for depth, downsample_block in enumerate(_self.down_blocks):
+            # Deep Shrink
+            if self.ds_depth_1 is not None:
+                if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
+                    self.ds_depth_2 is not None
+                    and depth == self.ds_depth_2
+                    and timesteps[0] < self.ds_timesteps_1
+                    and timesteps[0] >= self.ds_timesteps_2
+                ):
+                    org_dtype = sample.dtype
+                    if org_dtype == torch.bfloat16:
+                        sample = sample.to(torch.float32)
+                    sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
+
+            # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
+            # まあこちらのほうがわかりやすいかもしれない
+            if downsample_block.has_cross_attention:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+            down_block_res_samples += res_samples
+
+        # skip connectionにControlNetの出力を追加する
+        if down_block_additional_residuals is not None:
+            down_block_res_samples = list(down_block_res_samples)
+            for i in range(len(down_block_res_samples)):
+                down_block_res_samples[i] += down_block_additional_residuals[i]
+            down_block_res_samples = tuple(down_block_res_samples)
+
+        # 4. mid
+        sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+        # ControlNetの出力を追加する
+        if mid_block_additional_residual is not None:
+            sample += mid_block_additional_residual
+
+        # 5. up
+        for i, upsample_block in enumerate(_self.up_blocks):
+            is_final_block = i == len(_self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]  # skip connection
+
+            # if we have not reached the final block and need to forward the upsample size, we do it here
+            # 前述のように最後のブロック以外ではupsample_sizeを伝える
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
+
+        # 6. post-process
+        sample = _self.conv_norm_out(sample)
+        sample = _self.conv_act(sample)
+        sample = _self.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return SampleOutput(sample=sample)
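In the inference scripts the wrapper is applied to an already-loaded UNet, and Deep Shrink is configured on the wrapper rather than the model. A minimal usage sketch mirroring the gen_img_diffusers.py hunks earlier in this patch (`unet`, `dtype`, and `device` come from the surrounding script; the Deep Shrink values are illustrative):

    # illustrative call pattern, not part of the patch
    unet = InferUNet2DConditionModel(unet)  # wrap the loaded UNet2DConditionModel
    unet.to(dtype).to(device)               # __getattr__ forwards .to()/.eval() to the wrapped nn.Module
    unet.eval()

    # optional: halve the hidden-state resolution at down-block depth 3 while timestep >= 650
    unet.set_deep_shrink(3, ds_timesteps_1=650, ds_ratio=0.5)
    # unet.set_deep_shrink(None)            # disable Deep Shrink again

Keeping the ds_* state on a plain Python wrapper instead of the nn.Module means the model's state_dict and the training forward path stay untouched.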
diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py
index d51dfdbc..babda8ec 100644
--- a/library/sdxl_original_unet.py
+++ b/library/sdxl_original_unet.py
@@ -24,7 +24,7 @@ import math
 from types import SimpleNamespace
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -1013,31 +1013,6 @@ class SdxlUNet2DConditionModel(nn.Module):
             [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
         )
 
-        # Deep Shrink
-        self.ds_depth_1 = None
-        self.ds_depth_2 = None
-        self.ds_timesteps_1 = None
-        self.ds_timesteps_2 = None
-        self.ds_ratio = None
-
-    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
-        if ds_depth_1 is None:
-            print("Deep Shrink is disabled.")
-            self.ds_depth_1 = None
-            self.ds_timesteps_1 = None
-            self.ds_depth_2 = None
-            self.ds_timesteps_2 = None
-            self.ds_ratio = None
-        else:
-            print(
-                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
-            )
-            self.ds_depth_1 = ds_depth_1
-            self.ds_timesteps_1 = ds_timesteps_1
-            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
-            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
-            self.ds_ratio = ds_ratio
-
     # region diffusers compatibility
     def prepare_config(self):
         self.config = SimpleNamespace()
@@ -1120,7 +1095,97 @@ class SdxlUNet2DConditionModel(nn.Module):
         # h = x.type(self.dtype)
         h = x
 
-        for depth, module in enumerate(self.input_blocks):
+        for module in self.input_blocks:
+            h = call_module(module, h, emb, context)
+            hs.append(h)
+
+        h = call_module(self.middle_block, h, emb, context)
+
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = call_module(module, h, emb, context)
+
+        h = h.type(x.dtype)
+        h = call_module(self.out, h, emb, context)
+
+        return h
+
+
+class InferSdxlUNet2DConditionModel:
+    def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
+        self.delegate = original_unet
+
+        # override original model's forward method: because forward is not called by `__call__`
+        # overriding `__call__` is not enough, because nn.Module.forward has a special handling
+        self.delegate.forward = self.forward
+
+        # Deep Shrink
+        self.ds_depth_1 = None
+        self.ds_depth_2 = None
+        self.ds_timesteps_1 = None
+        self.ds_timesteps_2 = None
+        self.ds_ratio = None
+
+    # call original model's methods
+    def __getattr__(self, name):
+        return getattr(self.delegate, name)
+
+    def __call__(self, *args, **kwargs):
+        return self.delegate(*args, **kwargs)
+
+    def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
+        if ds_depth_1 is None:
+            print("Deep Shrink is disabled.")
+            self.ds_depth_1 = None
+            self.ds_timesteps_1 = None
+            self.ds_depth_2 = None
+            self.ds_timesteps_2 = None
+            self.ds_ratio = None
+        else:
+            print(
+                f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
+            )
+            self.ds_depth_1 = ds_depth_1
+            self.ds_timesteps_1 = ds_timesteps_1
+            self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
+            self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
+            self.ds_ratio = ds_ratio
+
+    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+        r"""
+        current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
+        """
+        _self = self.delegate
+
+        # broadcast timesteps to batch dimension
+        timesteps = timesteps.expand(x.shape[0])
+
+        hs = []
+        t_emb = get_timestep_embedding(timesteps, _self.model_channels)  # , repeat_only=False)
+        t_emb = t_emb.to(x.dtype)
+        emb = _self.time_embed(t_emb)
+
+        assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
+        assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
+        # assert x.dtype == _self.dtype
+        emb = emb + _self.label_emb(y)
+
+        def call_module(module, h, emb, context):
+            x = h
+            for layer in module:
+                # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
+                if isinstance(layer, ResnetBlock2D):
+                    x = layer(x, emb)
+                elif isinstance(layer, Transformer2DModel):
+                    x = layer(x, context)
+                else:
+                    x = layer(x)
+            return x
+
+        # h = x.type(self.dtype)
+        h = x
+
+        for depth, module in enumerate(_self.input_blocks):
             # Deep Shrink
             if self.ds_depth_1 is not None:
                 if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
@@ -1138,9 +1203,9 @@ class SdxlUNet2DConditionModel(nn.Module):
             h = call_module(module, h, emb, context)
             hs.append(h)
 
-        h = call_module(self.middle_block, h, emb, context)
+        h = call_module(_self.middle_block, h, emb, context)
 
-        for module in self.output_blocks:
+        for module in _self.output_blocks:
             # Deep Shrink
             if self.ds_depth_1 is not None:
                 if hs[-1].shape[-2:] != h.shape[-2:]:
@@ -1156,7 +1221,7 @@ class SdxlUNet2DConditionModel(nn.Module):
             h = resize_like(h, x)
 
         h = h.type(x.dtype)
-        h = call_module(self.out, h, emb, context)
+        h = call_module(_self.out, h, emb, context)
 
         return h
diff --git a/library/slicing_vae.py b/library/slicing_vae.py
index 31b2bd0a..5c4e056d 100644
--- a/library/slicing_vae.py
+++ b/library/slicing_vae.py
@@ -62,7 +62,7 @@ def cat_h(sliced):
     return x
 
 
-def resblock_forward(_self, num_slices, input_tensor, temb):
+def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
     assert _self.upsample is None and _self.downsample is None
     assert _self.norm1.num_groups == _self.norm2.num_groups
     assert temb is None
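Both Infer* wrappers above gate the same downscale step: at the configured down-block depth, while the current timestep is still in the configured range, the hidden states are resized by ds_ratio. A self-contained sketch of that core (not part of the patch; the fp32 round-trip is there because bicubic interpolation rejects bfloat16 inputs on some PyTorch builds):

    import torch
    import torch.nn.functional as F

    # standalone sketch of the Deep Shrink downscale shared by both wrappers
    sample = torch.randn(1, 320, 96, 96, dtype=torch.bfloat16)  # example hidden states

    org_dtype = sample.dtype
    if org_dtype == torch.bfloat16:  # bicubic interpolate may not support bf16, round-trip via fp32
        sample = sample.to(torch.float32)
    sample = F.interpolate(sample, scale_factor=0.5, mode="bicubic", align_corners=False).to(org_dtype)

    print(sample.shape, sample.dtype)  # torch.Size([1, 320, 48, 48]) torch.bfloat16

On the way back up, the overridden up-block forwards call resize_like whenever a shrunken hidden state no longer matches its skip connection's resolution, which is why that resize logic could be removed from the training-path blocks above.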
diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py
index dba7cd4e..6357df55 100644
--- a/networks/extract_lora_from_models.py
+++ b/networks/extract_lora_from_models.py
@@ -13,8 +13,8 @@ from library import sai_model_spec, model_util, sdxl_model_util
 import lora
 
-CLAMP_QUANTILE = 0.99
-MIN_DIFF = 1e-1
+# CLAMP_QUANTILE = 0.99
+# MIN_DIFF = 1e-1
 
 
 def save_to_file(file_name, model, state_dict, dtype):
@@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype):
         torch.save(model, file_name)
 
 
-def svd(args):
+def svd(
+    model_org=None,
+    model_tuned=None,
+    save_to=None,
+    dim=4,
+    v2=None,
+    sdxl=None,
+    conv_dim=None,
+    v_parameterization=None,
+    device=None,
+    save_precision=None,
+    clamp_quantile=0.99,
+    min_diff=0.01,
+    no_metadata=False,
+):
     def str_to_dtype(p):
         if p == "float":
             return torch.float
@@ -39,44 +53,42 @@ def svd(args):
             return torch.bfloat16
         return None
 
-    assert args.v2 != args.sdxl or (
-        not args.v2 and not args.sdxl
-    ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
-    if args.v_parameterization is None:
-        args.v_parameterization = args.v2
+    assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
+    if v_parameterization is None:
+        v_parameterization = v2
 
-    save_dtype = str_to_dtype(args.save_precision)
+    save_dtype = str_to_dtype(save_precision)
 
     # load models
-    if not args.sdxl:
-        print(f"loading original SD model : {args.model_org}")
-        text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
+    if not sdxl:
+        print(f"loading original SD model : {model_org}")
+        text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
         text_encoders_o = [text_encoder_o]
-        print(f"loading tuned SD model : {args.model_tuned}")
-        text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
+        print(f"loading tuned SD model : {model_tuned}")
+        text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
         text_encoders_t = [text_encoder_t]
-        model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
+        model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
     else:
-        print(f"loading original SDXL model : {args.model_org}")
+        print(f"loading original SDXL model : {model_org}")
         text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
-            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
+            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
         )
         text_encoders_o = [text_encoder_o1, text_encoder_o2]
-        print(f"loading original SDXL model : {args.model_tuned}")
+        print(f"loading original SDXL model : {model_tuned}")
         text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
-            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
+            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
        )
         text_encoders_t = [text_encoder_t1, text_encoder_t2]
         model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
 
     # create LoRA network to extract weights: Use dim (rank) as alpha
-    if args.conv_dim is None:
+    if conv_dim is None:
         kwargs = {}
     else:
-        kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
+        kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
 
-    lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
-    lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
+    lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
+    lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
     assert len(lora_network_o.text_encoder_loras) == len(
         lora_network_t.text_encoder_loras
     ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
@@ -91,9 +103,9 @@ def svd(args):
             diff = module_t.weight - module_o.weight
 
             # Text Encoder might be same
-            if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
+            if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
                 text_encoder_different = True
-                print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
+                print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
 
             diff = diff.float()
             diffs[lora_name] = diff
@@ -120,16 +132,16 @@ def svd(args):
     lora_weights = {}
     with torch.no_grad():
         for lora_name, mat in tqdm(list(diffs.items())):
-            # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
+            # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
             conv2d = len(mat.size()) == 4
             kernel_size = None if not conv2d else mat.size()[2:4]
             conv2d_3x3 = conv2d and kernel_size != (1, 1)
 
-            rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
+            rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
             out_dim, in_dim = mat.size()[0:2]
 
-            if args.device:
-                mat = mat.to(args.device)
+            if device:
+                mat = mat.to(device)
 
             # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
             rank = min(rank, in_dim, out_dim)  # LoRA rank cannot exceed the original dim
@@ -149,7 +161,7 @@ def svd(args):
             Vh = Vh[:rank, :]
 
             dist = torch.cat([U.flatten(), Vh.flatten()])
-            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+            hi_val = torch.quantile(dist, clamp_quantile)
             low_val = -hi_val
 
             U = U.clamp(low_val, hi_val)
@@ -178,34 +190,32 @@ def svd(args):
     info = lora_network_save.load_state_dict(lora_sd)
     print(f"Loading extracted LoRA weights: {info}")
 
-    dir_name = os.path.dirname(args.save_to)
+    dir_name = os.path.dirname(save_to)
     if dir_name and not os.path.exists(dir_name):
         os.makedirs(dir_name, exist_ok=True)
 
     # minimum metadata
     net_kwargs = {}
-    if args.conv_dim is not None:
-        net_kwargs["conv_dim"] = args.conv_dim
-        net_kwargs["conv_alpha"] = args.conv_dim
+    if conv_dim is not None:
+        net_kwargs["conv_dim"] = str(conv_dim)
+        net_kwargs["conv_alpha"] = str(float(conv_dim))
 
     metadata = {
-        "ss_v2": str(args.v2),
+        "ss_v2": str(v2),
         "ss_base_model_version": model_version,
         "ss_network_module": "networks.lora",
-        "ss_network_dim": str(args.dim),
-        "ss_network_alpha": str(args.dim),
+        "ss_network_dim": str(dim),
+        "ss_network_alpha": str(float(dim)),
         "ss_network_args": json.dumps(net_kwargs),
     }
 
-    if not args.no_metadata:
-        title = os.path.splitext(os.path.basename(args.save_to))[0]
-        sai_metadata = sai_model_spec.build_metadata(
-            None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
-        )
+    if not no_metadata:
+        title = os.path.splitext(os.path.basename(save_to))[0]
+        sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
         metadata.update(sai_metadata)
 
-    lora_network_save.save_weights(args.save_to, save_dtype, metadata)
-    print(f"LoRA weights are saved to: {args.save_to}")
+    lora_network_save.save_weights(save_to, save_dtype, metadata)
+    print(f"LoRA weights are saved to: {save_to}")
 
 
 def setup_parser() -> argparse.ArgumentParser:
@@ -213,7 +223,7 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
     parser.add_argument(
         "--v_parameterization",
-        type=bool,
+        action="store_true",
         default=None,
         help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
     )
@@ -231,16 +241,22 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--model_org",
         type=str,
         default=None,
+        required=True,
         help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
     )
     parser.add_argument(
         "--model_tuned",
         type=str,
         default=None,
+        required=True,
        help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
     )
     parser.add_argument(
-        "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
+        "--save_to",
+        type=str,
+        default=None,
+        required=True,
+        help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
     )
     parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
     parser.add_argument(
@@ -250,6 +266,19 @@ def setup_parser() -> argparse.ArgumentParser:
         help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
     )
     parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+    parser.add_argument(
+        "--clamp_quantile",
+        type=float,
+        default=0.99,
+        help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
+    )
+    parser.add_argument(
+        "--min_diff",
+        type=float,
+        default=0.01,
+        help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
+        + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
+    )
     parser.add_argument(
         "--no_metadata",
         action="store_true",
@@ -264,4 +293,4 @@ if __name__ == "__main__":
     parser = setup_parser()
     args = parser.parse_args()
 
-    svd(args)
+    svd(**vars(args))
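Since svd() now takes plain keyword arguments instead of an argparse.Namespace, the extractor can also be driven from other Python code. An illustrative sketch (the checkpoint paths are placeholders):

    # illustrative programmatic call; the file paths are placeholders
    from networks.extract_lora_from_models import svd

    svd(
        model_org="sd15-base.safetensors",         # base checkpoint
        model_tuned="sd15-finetuned.safetensors",  # tuned checkpoint; the LoRA approximates tuned - base
        save_to="extracted_lora.safetensors",
        dim=8,                                     # LoRA rank, also used as alpha
        v2=False,
        sdxl=False,
        device="cuda",                             # run the per-layer SVD on GPU
        clamp_quantile=0.99,
        min_diff=0.01,
    )

The switch of --v_parameterization from type=bool to action="store_true" also fixes a classic argparse pitfall: type=bool applies bool() to the raw string, so `--v_parameterization False` would have evaluated to True, since any non-empty string is truthy.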
diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py
index 8be04643..c8bd38dd 100755
--- a/sdxl_gen_img.py
+++ b/sdxl_gen_img.py
@@ -57,7 +57,7 @@ import library.train_util as train_util
 import library.sdxl_model_util as sdxl_model_util
 import library.sdxl_train_util as sdxl_train_util
 from networks.lora import LoRANetwork
-from library.sdxl_original_unet import SdxlUNet2DConditionModel
+from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
 from library.original_unet import FlashAttentionFunction
 from networks.control_net_lllite import ControlNetLLLite
 
@@ -290,7 +290,7 @@ class PipelineLike:
         vae: AutoencoderKL,
         text_encoders: List[CLIPTextModel],
         tokenizers: List[CLIPTokenizer],
-        unet: SdxlUNet2DConditionModel,
+        unet: InferSdxlUNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         clip_skip: int,
     ):
@@ -328,7 +328,7 @@ class PipelineLike:
         self.vae = vae
         self.text_encoders = text_encoders
         self.tokenizers = tokenizers
-        self.unet: SdxlUNet2DConditionModel = unet
+        self.unet: InferSdxlUNet2DConditionModel = unet
         self.scheduler = scheduler
         self.safety_checker = None
 
@@ -1611,6 +1611,7 @@ def main(args):
     (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
         args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
     )
+    unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
 
     # xformers、Hypernetwork対応
     if not args.diffusers_xformers:
@@ -1766,10 +1767,14 @@ def main(args):
             print("set vae_dtype to float32")
             vae_dtype = torch.float32
     vae.to(vae_dtype).to(device)
+    vae.eval()
 
     text_encoder1.to(dtype).to(device)
     text_encoder2.to(dtype).to(device)
     unet.to(dtype).to(device)
+    text_encoder1.eval()
+    text_encoder2.eval()
+    unet.eval()
 
     # networkを組み込む
     if args.network_module:
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 54abf697..44447d1f 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -460,7 +460,7 @@ def train(args):
                     loss = loss * loss_weights
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                 if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index f00f10ea..91cbacc6 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -430,7 +430,7 @@ def train(args):
                 loss = loss * loss_weights
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                 if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
diff --git a/train_controlnet.py b/train_controlnet.py
index bbd915cb..e0118d1c 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -449,7 +449,7 @@ def train(args):
                 loss = loss * loss_weights
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
diff --git a/train_db.py b/train_db.py
index 7fbbc18a..966999df 100644
--- a/train_db.py
+++ b/train_db.py
@@ -342,7 +342,7 @@ def train(args):
                 loss = loss * loss_weights
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                 if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
diff --git a/train_network.py b/train_network.py
index d50916b7..1cbed2e7 100644
--- a/train_network.py
+++ b/train_network.py
@@ -812,7 +812,7 @@ class NetworkTrainer:
                     loss = loss * loss_weights
 
                     if args.min_snr_gamma:
-                        loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                        loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                     if args.scale_v_pred_loss_like_noise_pred:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 6b6e7f5a..45a437b9 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -578,7 +578,7 @@ class TextualInversionTrainer:
                     loss = loss * loss_weights
 
                     if args.min_snr_gamma:
-                        loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                        loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                     if args.scale_v_pred_loss_like_noise_pred:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 8dd5c672..f77ad2eb 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -469,7 +469,7 @@ def train(args):
                 loss = loss * loss_weights
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
                 if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
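Finally, since Deep Shrink is off by default, the wrappers should behave as drop-in replacements for the bare models. A rough smoke-test sketch for the SD1.x wrapper (this assumes UNet2DConditionModel() builds SD1.x geometry from its default constructor arguments, which may need adjusting for your checkout):

    # rough smoke test; assumes default-constructed SD1.x geometry (verify against your checkout)
    import torch
    from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel

    unet = InferUNet2DConditionModel(UNet2DConditionModel())
    unet.eval()

    sample = torch.randn(1, 4, 64, 64)  # 512x512 image -> 64x64 latents
    timestep = torch.tensor([500])
    context = torch.randn(1, 77, 768)   # SD1.x text-encoder hidden states

    with torch.no_grad():
        out = unet(sample, timestep, context).sample  # ds_depth_1 is None -> plain UNet path
    assert out.shape == sample.shape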