diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index c506ad3f..7d9c68bf 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -37,7 +37,7 @@ from diffusers import ( from einops import rearrange from tqdm import tqdm from torchvision import transforms -from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor import PIL from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -61,6 +61,8 @@ SCHEDLER_SCHEDULE = "scaled_linear" LATENT_CHANNELS = 4 DOWNSAMPLING_FACTOR = 8 +CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + # region モジュール入れ替え部 """ 高速化のためのモジュール入れ替え @@ -320,6 +322,10 @@ class PipelineLike: self.scheduler = scheduler self.safety_checker = None + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + # Textual Inversion self.token_replacements_list = [] for _ in range(len(self.text_encoders)): @@ -535,6 +541,21 @@ class PipelineLike: num_sub_prompts = len(text_pool) // batch_size text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt + if init_image is not None and self.clip_vision_model is not None: + print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) + pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) + + clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True) + clip_vision_embeddings = clip_vision_embeddings.image_embeds + + if len(clip_vision_embeddings) == 1 and batch_size > 1: + clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) + + clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength + assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" + text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) + c_vector = torch.cat([text_pool, c_vector], dim=1) uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) @@ -1767,6 +1788,19 @@ def main(args): init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" print(f"loaded {len(init_images)} images for img2img") + + # CLIP Vision + if args.clip_vision_strength is not None: + print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) + vision_model.to(device, dtype) + processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) + + pipe.clip_vision_model = vision_model + pipe.clip_vision_processor = processor + pipe.clip_vision_strength = args.clip_vision_strength + print(f"CLIP Vision model loaded.") + else: init_images = None @@ -2656,6 +2690,12 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", ) + parser.add_argument( + "--clip_vision_strength", + type=float, + default=None, + help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", + ) # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # )