mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
get pool from CLIPVisionModel in img2img
This commit is contained in:
@@ -37,7 +37,7 @@ from diffusers import (
|
||||
from einops import rearrange
|
||||
from tqdm import tqdm
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
@@ -61,6 +61,8 @@ SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
LATENT_CHANNELS = 4
|
||||
DOWNSAMPLING_FACTOR = 8
|
||||
|
||||
CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
# region モジュール入れ替え部
|
||||
"""
|
||||
高速化のためのモジュール入れ替え
|
||||
@@ -320,6 +322,10 @@ class PipelineLike:
|
||||
self.scheduler = scheduler
|
||||
self.safety_checker = None
|
||||
|
||||
self.clip_vision_model: CLIPVisionModelWithProjection = None
|
||||
self.clip_vision_processor: CLIPImageProcessor = None
|
||||
self.clip_vision_strength = 0.0
|
||||
|
||||
# Textual Inversion
|
||||
self.token_replacements_list = []
|
||||
for _ in range(len(self.text_encoders)):
|
||||
@@ -535,6 +541,21 @@ class PipelineLike:
|
||||
num_sub_prompts = len(text_pool) // batch_size
|
||||
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
|
||||
|
||||
if init_image is not None and self.clip_vision_model is not None:
|
||||
print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
|
||||
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
|
||||
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
|
||||
|
||||
clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True)
|
||||
clip_vision_embeddings = clip_vision_embeddings.image_embeds
|
||||
|
||||
if len(clip_vision_embeddings) == 1 and batch_size > 1:
|
||||
clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1))
|
||||
|
||||
clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength
|
||||
assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}"
|
||||
text_pool = clip_vision_embeddings # replace: same as ComfyUI (?)
|
||||
|
||||
c_vector = torch.cat([text_pool, c_vector], dim=1)
|
||||
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
|
||||
|
||||
@@ -1767,6 +1788,19 @@ def main(args):
|
||||
init_images = load_images(args.image_path)
|
||||
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
|
||||
print(f"loaded {len(init_images)} images for img2img")
|
||||
|
||||
# CLIP Vision
|
||||
if args.clip_vision_strength is not None:
|
||||
print(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
|
||||
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
|
||||
vision_model.to(device, dtype)
|
||||
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
|
||||
|
||||
pipe.clip_vision_model = vision_model
|
||||
pipe.clip_vision_processor = processor
|
||||
pipe.clip_vision_strength = args.clip_vision_strength
|
||||
print(f"CLIP Vision model loaded.")
|
||||
|
||||
else:
|
||||
init_images = None
|
||||
|
||||
@@ -2656,6 +2690,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
nargs="*",
|
||||
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_vision_strength",
|
||||
type=float,
|
||||
default=None,
|
||||
help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する",
|
||||
)
|
||||
# # parser.add_argument(
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
# )
|
||||
|
||||
Reference in New Issue
Block a user