update PyTorch version and reorganize dependencies

This commit is contained in:
Kohya S
2024-03-24 22:06:37 +09:00
parent 1648ade6da
commit 9bbb28c361
6 changed files with 136 additions and 191 deletions

View File

@@ -11,20 +11,24 @@ import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
from PIL import Image
import open_clip
# import open_clip
from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
@@ -154,12 +158,13 @@ if __name__ == "__main__":
text_model2.eval()
unet.set_use_memory_efficient_attention(True, False)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(True)
# Tokenizers
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
# tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name)
# LoRA
for weights_file in args.lora_weights:
@@ -192,7 +197,9 @@ if __name__ == "__main__":
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# logger.info("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
uc_vector = c_vector.clone().to(
DEVICE, dtype=DTYPE
) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
# crossattn
@@ -215,13 +222,22 @@ if __name__ == "__main__":
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
# text encoder 2
with torch.no_grad():
tokens = tokenizer2(text2).to(DEVICE)
# tokens = tokenizer2(text2).to(DEVICE)
tokens = tokenizer2(
text,
truncation=True,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(DEVICE)
with torch.no_grad():
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# logger.info("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# 連結して終了 concat and finish
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)