get pool from CLIPVisionModel in img2img

Kohya S
2023-09-13 20:58:37 +09:00
parent 90c47140b8
commit d337bbf8a0


@@ -37,7 +37,7 @@ from diffusers import (
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@@ -61,6 +61,8 @@ SCHEDLER_SCHEDULE = "scaled_linear"
LATENT_CHANNELS = 4
DOWNSAMPLING_FACTOR = 8

+CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

# region モジュール入れ替え部
"""
高速化のためのモジュール入れ替え
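The new constant points at the OpenCLIP ViT-bigG/14 vision tower, the same model family as SDXL's second text encoder. A minimal sketch of how the two newly imported transformers classes turn an init image into a pooled image embedding (the example image path is an assumption, not part of the commit):

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

# Preprocessing: resize/normalize the PIL image into pixel_values.
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
# Vision tower plus a projection head; projection_dim=1280 matches the width
# of the pooled text embedding (text_pool) from SDXL's second text encoder.
model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)

image = Image.open("init.png").convert("RGB")  # placeholder path
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    out = model(pixel_values=inputs["pixel_values"])
print(out.image_embeds.shape)  # torch.Size([1, 1280])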
@@ -320,6 +322,10 @@ class PipelineLike:
        self.scheduler = scheduler
        self.safety_checker = None

+        self.clip_vision_model: CLIPVisionModelWithProjection = None
+        self.clip_vision_processor: CLIPImageProcessor = None
+        self.clip_vision_strength = 0.0

        # Textual Inversion
        self.token_replacements_list = []
        for _ in range(len(self.text_encoders)):
@@ -535,6 +541,21 @@ class PipelineLike:
                num_sub_prompts = len(text_pool) // batch_size
                text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts]  # last subprompt

+        if init_image is not None and self.clip_vision_model is not None:
+            print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
+            vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
+            pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
+
+            clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True)
+            clip_vision_embeddings = clip_vision_embeddings.image_embeds
+
+            if len(clip_vision_embeddings) == 1 and batch_size > 1:
+                clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1))
+
+            clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength
+            assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}"
+            text_pool = clip_vision_embeddings  # replace: same as ComfyUI (?)

        c_vector = torch.cat([text_pool, c_vector], dim=1)
        uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
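Functionally, the new branch swaps the prompt's pooled embedding for a strength-scaled image embedding of the init image, so the SDXL pooled conditioning comes from the picture rather than the text; the inline comment suggests this mirrors ComfyUI's handling. A rough standalone restatement of that logic, where the helper name and argument list are illustrative rather than part of the commit:

import torch

def image_pool_from_init_image(init_image, clip_vision_model, clip_vision_processor,
                               clip_vision_strength, text_pool, batch_size, device):
    """Return a replacement for text_pool computed from the init image."""
    # Preprocess the init image and run the projected vision model.
    pixel_values = clip_vision_processor(init_image, return_tensors="pt")["pixel_values"]
    pixel_values = pixel_values.to(device, dtype=text_pool.dtype)
    image_embeds = clip_vision_model(pixel_values=pixel_values).image_embeds  # (n, 1280)

    # A single init image is shared across the whole batch.
    if len(image_embeds) == 1 and batch_size > 1:
        image_embeds = image_embeds.repeat((batch_size, 1))

    # Scale by strength and require the same shape as the pooled text embedding,
    # since it is substituted one-for-one.
    image_embeds = image_embeds * clip_vision_strength
    assert image_embeds.shape == text_pool.shape, f"{image_embeds.shape} != {text_pool.shape}"
    return image_embeds

The result then flows into c_vector through the same torch.cat as text_pool did before, so nothing downstream of the pooled conditioning changes.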
@@ -1767,6 +1788,19 @@ def main(args):
        init_images = load_images(args.image_path)
        assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
        print(f"loaded {len(init_images)} images for img2img")

+        # CLIP Vision
+        if args.clip_vision_strength is not None:
+            print(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
+            vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
+            vision_model.to(device, dtype)
+            processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
+
+            pipe.clip_vision_model = vision_model
+            pipe.clip_vision_processor = processor
+            pipe.clip_vision_strength = args.clip_vision_strength
+            print(f"CLIP Vision model loaded.")
    else:
        init_images = None
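Passing projection_dim=1280 explicitly guarantees that image_embeds comes out with the same width as text_pool, which the assert in the encoding branch above depends on. A hypothetical sanity check, not in the commit, reusing the vision_model loaded here:

# visual_projection is the final linear layer that produces image_embeds;
# its output width must equal the 1280-dim pooled embedding of SDXL's
# second text encoder for the one-for-one replacement to work.
assert vision_model.visual_projection.out_features == 1280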
@@ -2656,6 +2690,12 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*", nargs="*",
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
) )
parser.add_argument(
"--clip_vision_strength",
type=float,
default=None,
help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する",
)
# # parser.add_argument( # # parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# ) # )
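Because --clip_vision_strength defaults to None, existing behaviour is unchanged unless the flag is passed; the vision model is only loaded when it is combined with --image_path for img2img. An illustrative invocation, where the script name and the remaining flags are placeholders for an existing img2img command line rather than something taken from this commit:

python sdxl_gen_img.py --ckpt sd_xl_base_1.0.safetensors --prompt "a photo of ..." --image_path init.png --strength 0.75 --clip_vision_strength 1.0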