Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)

get pool from CLIPVisionModel in img2img
@@ -37,7 +37,7 @@ from diffusers import (
 from einops import rearrange
 from tqdm import tqdm
 from torchvision import transforms
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
 import PIL
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
@@ -61,6 +61,8 @@ SCHEDLER_SCHEDULE = "scaled_linear"
 LATENT_CHANNELS = 4
 DOWNSAMPLING_FACTOR = 8
 
+CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+
 # region モジュール入れ替え部
 """
 高速化のためのモジュール入れ替え
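Note (not part of the patch): CLIP_VISION_MODEL refers to the Hugging Face checkpoint of OpenCLIP ViT-bigG/14 trained on LAION-2B, the same model family as SDXL's second text encoder, so its projected image embedding shares the 1280-dimensional width of the SDXL pooled text embedding used in the later hunks.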
@@ -320,6 +322,10 @@ class PipelineLike:
         self.scheduler = scheduler
         self.safety_checker = None
 
+        self.clip_vision_model: CLIPVisionModelWithProjection = None
+        self.clip_vision_processor: CLIPImageProcessor = None
+        self.clip_vision_strength = 0.0
+
         # Textual Inversion
         self.token_replacements_list = []
         for _ in range(len(self.text_encoders)):
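Note (not part of the patch): the new attributes default to None / 0.0, so PipelineLike behaves exactly as before unless main() wires in a vision model, a processor and a strength (see the hunk at line 1788 below).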
@@ -535,6 +541,21 @@ class PipelineLike:
         num_sub_prompts = len(text_pool) // batch_size
         text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts]  # last subprompt
 
+        if init_image is not None and self.clip_vision_model is not None:
+            print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
+            vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
+            pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
+
+            clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True)
+            clip_vision_embeddings = clip_vision_embeddings.image_embeds
+
+            if len(clip_vision_embeddings) == 1 and batch_size > 1:
+                clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1))
+
+            clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength
+            assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}"
+            text_pool = clip_vision_embeddings  # replace: same as ComfyUI (?)
+
         c_vector = torch.cat([text_pool, c_vector], dim=1)
         uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
 
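As a side note (not part of the patch), the conditioning step reduces to roughly the sketch below. It assumes a PIL init image and that the SDXL pooled text embedding (text_pool above) has shape (batch_size, 1280); the helper name image_pool and the device/dtype choices are illustrative, not taken from the script.

    import torch
    from PIL import Image
    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

    # Same checkpoint as the CLIP_VISION_MODEL constant added by this commit.
    MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

    device, dtype = "cuda", torch.float16
    processor = CLIPImageProcessor.from_pretrained(MODEL)
    vision_model = CLIPVisionModelWithProjection.from_pretrained(MODEL, projection_dim=1280).to(device, dtype=dtype)

    @torch.no_grad()
    def image_pool(init_image: Image.Image, batch_size: int, strength: float) -> torch.Tensor:
        # Preprocess the init image into the pixel values the vision tower expects.
        pixel_values = processor(init_image, return_tensors="pt")["pixel_values"].to(device, dtype=dtype)
        # Projected image embedding: shape (1, 1280) for a single image.
        image_embeds = vision_model(pixel_values=pixel_values).image_embeds
        if image_embeds.shape[0] == 1 and batch_size > 1:
            image_embeds = image_embeds.repeat((batch_size, 1))  # broadcast over the batch
        # Scale and return; in the patch this replaces text_pool before it is
        # concatenated into the SDXL vector conditioning (c_vector).
        return image_embeds * strength

Because the scaled image embedding replaces text_pool outright rather than being added to it, clip_vision_strength trades the pooled text conditioning for image conditioning, while the per-token text embeddings used for cross-attention are left untouched.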
@@ -1767,6 +1788,19 @@ def main(args):
         init_images = load_images(args.image_path)
         assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
         print(f"loaded {len(init_images)} images for img2img")
+
+        # CLIP Vision
+        if args.clip_vision_strength is not None:
+            print(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
+            vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
+            vision_model.to(device, dtype)
+            processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
+
+            pipe.clip_vision_model = vision_model
+            pipe.clip_vision_processor = processor
+            pipe.clip_vision_strength = args.clip_vision_strength
+            print(f"CLIP Vision model loaded.")
+
     else:
         init_images = None
 
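Note (not part of the patch): projection_dim=1280 pins the width of image_embeds to the 1280-dimensional pooled output of SDXL's second text encoder (OpenCLIP ViT-bigG/14); the assert in PipelineLike compares the image embedding's shape against text_pool, so the two widths have to agree. vision_model.to(device, dtype) moves the tower to the same device and precision as the rest of the pipeline.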
@@ -2656,6 +2690,12 @@ def setup_parser() -> argparse.ArgumentParser:
         nargs="*",
         help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
     )
+    parser.add_argument(
+        "--clip_vision_strength",
+        type=float,
+        default=None,
+        help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する",
+    )
     # # parser.add_argument(
     # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
     # )
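Usage note (not part of the patch): the flag is off by default (None). Passing e.g. --clip_vision_strength 0.5 together with --image_path on an img2img run loads the bigG vision tower and substitutes the scaled image embedding for the pooled text conditioning, as shown in the hunks above.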