mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Update dependencies ref #1024
This commit is contained in:
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
import diffusers
|
||||
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
||||
@@ -520,6 +520,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
clip_skip: int = 1,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -531,32 +532,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.clip_skip = clip_skip
|
||||
self.custom_clip_skip = clip_skip
|
||||
self.__init__additional__()
|
||||
|
||||
# else:
|
||||
# def __init__(
|
||||
# self,
|
||||
# vae: AutoencoderKL,
|
||||
# text_encoder: CLIPTextModel,
|
||||
# tokenizer: CLIPTokenizer,
|
||||
# unet: UNet2DConditionModel,
|
||||
# scheduler: SchedulerMixin,
|
||||
# safety_checker: StableDiffusionSafetyChecker,
|
||||
# feature_extractor: CLIPFeatureExtractor,
|
||||
# ):
|
||||
# super().__init__(
|
||||
# vae=vae,
|
||||
# text_encoder=text_encoder,
|
||||
# tokenizer=tokenizer,
|
||||
# unet=unet,
|
||||
# scheduler=scheduler,
|
||||
# safety_checker=safety_checker,
|
||||
# feature_extractor=feature_extractor,
|
||||
# )
|
||||
# self.__init__additional__()
|
||||
|
||||
def __init__additional__(self):
|
||||
if not hasattr(self, "vae_scale_factor"):
|
||||
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
||||
@@ -624,7 +604,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
prompt=prompt,
|
||||
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
clip_skip=self.clip_skip,
|
||||
clip_skip=self.custom_clip_skip,
|
||||
)
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
|
||||
@@ -4,10 +4,13 @@
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -571,9 +574,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
# support checkpoint without position_ids (invalid checkpoint)
|
||||
if "text_model.embeddings.position_ids" not in text_model_dict:
|
||||
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
||||
# remove position_ids for newer transformer, which causes error :(
|
||||
if "text_model.embeddings.position_ids" in text_model_dict:
|
||||
text_model_dict.pop("text_model.embeddings.position_ids")
|
||||
|
||||
return text_model_dict
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
key = key.replace(".ln_final", ".final_layer_norm")
|
||||
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
|
||||
elif ".embeddings.position_ids" in key:
|
||||
key = None # remove this key: make position_ids by ourselves
|
||||
key = None # remove this key: position_ids is not used in newer transformers
|
||||
return key
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
@@ -126,10 +126,6 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
||||
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
||||
|
||||
# original SD にはないので、position_idsを追加
|
||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||
|
||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||
|
||||
@@ -265,9 +261,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
||||
elif k.startswith("conditioner.embedders.1.model."):
|
||||
te2_sd[k] = state_dict.pop(k)
|
||||
|
||||
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
|
||||
if "text_model.embeddings.position_ids" not in te1_sd:
|
||||
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
|
||||
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
|
||||
if "text_model.embeddings.position_ids" in te1_sd:
|
||||
te1_sd.pop("text_model.embeddings.position_ids")
|
||||
|
||||
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
||||
print("text encoder 1:", info1)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
accelerate==0.23.0
|
||||
transformers==4.30.2
|
||||
diffusers[torch]==0.21.2
|
||||
accelerate==0.25.0
|
||||
transformers==4.36.2
|
||||
diffusers[torch]==0.25.0
|
||||
ftfy==6.1.1
|
||||
# albumentations==1.3.0
|
||||
opencv-python==4.7.0.68
|
||||
@@ -14,7 +14,7 @@ altair==4.2.2
|
||||
easygui==0.98.3
|
||||
toml==0.10.2
|
||||
voluptuous==0.13.1
|
||||
huggingface-hub==0.15.1
|
||||
huggingface-hub==0.20.1
|
||||
# for BLIP captioning
|
||||
# requests==2.28.2
|
||||
# timm==0.6.12
|
||||
|
||||
Reference in New Issue
Block a user