Update dependencies ref #1024

Kohya S
2024-01-04 19:53:25 +09:00
parent 07bf2a21ac
commit 716bad188b
4 changed files with 19 additions and 40 deletions

library/lpw_stable_diffusion.py

@@ -9,7 +9,7 @@ import numpy as np
 import PIL.Image
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 import diffusers
 from diffusers import SchedulerMixin, StableDiffusionPipeline
@@ -520,6 +520,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
         requires_safety_checker: bool = True,
+        image_encoder: CLIPVisionModelWithProjection = None,
         clip_skip: int = 1,
     ):
         super().__init__(
@@ -531,32 +532,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             requires_safety_checker=requires_safety_checker,
+            image_encoder=image_encoder,
         )
-        self.clip_skip = clip_skip
+        self.custom_clip_skip = clip_skip
         self.__init__additional__()

-    # else:
-    #     def __init__(
-    #         self,
-    #         vae: AutoencoderKL,
-    #         text_encoder: CLIPTextModel,
-    #         tokenizer: CLIPTokenizer,
-    #         unet: UNet2DConditionModel,
-    #         scheduler: SchedulerMixin,
-    #         safety_checker: StableDiffusionSafetyChecker,
-    #         feature_extractor: CLIPFeatureExtractor,
-    #     ):
-    #         super().__init__(
-    #             vae=vae,
-    #             text_encoder=text_encoder,
-    #             tokenizer=tokenizer,
-    #             unet=unet,
-    #             scheduler=scheduler,
-    #             safety_checker=safety_checker,
-    #             feature_extractor=feature_extractor,
-    #         )
-    #         self.__init__additional__()
-
     def __init__additional__(self):
         if not hasattr(self, "vae_scale_factor"):
             setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
@@ -624,7 +604,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             prompt=prompt,
             uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
             max_embeddings_multiples=max_embeddings_multiples,
-            clip_skip=self.clip_skip,
+            clip_skip=self.custom_clip_skip,
         )
         bs_embed, seq_len, _ = text_embeddings.shape
         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
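
The two edits above are tied to the diffusers 0.25 bump: the base `StableDiffusionPipeline.__init__` now accepts an `image_encoder` (used for IP-Adapter), which must be forwarded, and the base class now exposes `clip_skip` as a read-only property, so assigning `self.clip_skip` would raise; the pipeline's own value therefore moves to `custom_clip_skip`. A minimal construction sketch, assuming stock components; the model id is only an example:

```python
from diffusers import StableDiffusionPipeline
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline

base = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # example id
pipe = StableDiffusionLongPromptWeightingPipeline(
    vae=base.vae,
    text_encoder=base.text_encoder,
    tokenizer=base.tokenizer,
    unet=base.unet,
    scheduler=base.scheduler,
    safety_checker=base.safety_checker,
    feature_extractor=base.feature_extractor,
    image_encoder=None,  # new argument, forwarded to the diffusers base class
    clip_skip=2,         # stored as self.custom_clip_skip to avoid the base property
)
```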

library/model_util.py

@@ -4,10 +4,13 @@
 import math
 import os
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
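
The guarded import keeps Intel Extension for PyTorch strictly optional: `torch.xpu` only becomes usable after `intel_extension_for_pytorch` is imported, and any failure falls through to the regular CUDA/CPU paths. A quick check of which path a machine takes, assuming nothing beyond the packages named here:

```python
import torch

try:
    import intel_extension_for_pytorch as ipex  # noqa: F401  # registers torch.xpu
    print("XPU available:", torch.xpu.is_available())
except ImportError:
    print("IPEX not installed; CUDA/CPU code paths are unaffected")
```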
@@ -571,9 +574,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
         if key.startswith("cond_stage_model.transformer"):
             text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

-    # support checkpoint without position_ids (invalid checkpoint)
-    if "text_model.embeddings.position_ids" not in text_model_dict:
-        text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)  # 77 is the max length of the text
+    # remove position_ids for newer transformers, which causes an error :(
+    if "text_model.embeddings.position_ids" in text_model_dict:
+        text_model_dict.pop("text_model.embeddings.position_ids")

     return text_model_dict
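
The logic flips because transformers changed underneath: since v4.31 the CLIP `position_ids` buffer is non-persistent, so it never appears in `state_dict()` and a strict `load_state_dict` rejects old checkpoints that still carry the key. A small demonstration of both facts, assuming only the transformers pinned below (4.36.2):

```python
import torch
from transformers import CLIPTextConfig, CLIPTextModel

model = CLIPTextModel(CLIPTextConfig())  # randomly initialized, default config
print("text_model.embeddings.position_ids" in model.state_dict())  # False on 4.36.x

# An old-format state dict that still carries the buffer fails a strict load:
sd = model.state_dict()
sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
try:
    model.load_state_dict(sd)  # strict=True by default
except RuntimeError as err:
    print("rejected:", "position_ids" in str(err))
```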

library/sdxl_model_util.py

@@ -100,7 +100,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
             key = key.replace(".ln_final", ".final_layer_norm")
         # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
         elif ".embeddings.position_ids" in key:
-            key = None  # remove this key: make position_ids by ourselves
+            key = None  # remove this key: position_ids is not used in newer transformers
         return key

     keys = list(checkpoint.keys())
@@ -126,10 +126,6 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
             new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
             new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

-    # add position_ids, which the original SD checkpoint does not have
-    position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
-    new_sd["text_model.embeddings.position_ids"] = position_ids
-
     # logit_scale is not included in the Diffusers model, but return it separately so it can be restored when saving
     logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
@@ -265,9 +261,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
         elif k.startswith("conditioner.embedders.1.model."):
             te2_sd[k] = state_dict.pop(k)

-    # add position_ids for models that lack them
-    if "text_model.embeddings.position_ids" not in te1_sd:
-        te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
+    # the latest transformers errors when position_ids is present, so remove it
+    if "text_model.embeddings.position_ids" in te1_sd:
+        te1_sd.pop("text_model.embeddings.position_ids")

     info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
     print("text encoder 1:", info1)

requirements.txt

@@ -1,6 +1,6 @@
-accelerate==0.23.0
-transformers==4.30.2
-diffusers[torch]==0.21.2
+accelerate==0.25.0
+transformers==4.36.2
+diffusers[torch]==0.25.0
 ftfy==6.1.1
 # albumentations==1.3.0
 opencv-python==4.7.0.68
@@ -14,7 +14,7 @@ altair==4.2.2
 easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
-huggingface-hub==0.15.1
+huggingface-hub==0.20.1
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12
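
After pulling this commit, `pip install -r requirements.txt` re-pins the environment; printing the resolved versions is a cheap check that no stale install survives. The expected numbers are simply the pins above:

```python
import accelerate
import diffusers
import huggingface_hub
import transformers

# expected per the pins: 0.25.0, 0.25.0, 0.20.1, 4.36.2
for mod in (accelerate, diffusers, huggingface_hub, transformers):
    print(f"{mod.__name__}=={mod.__version__}")
```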