mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
update PyTorch version and reorganize dependencies
This commit is contained in:
@@ -11,20 +11,24 @@ import numpy as np
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
from PIL import Image
|
||||
import open_clip
|
||||
|
||||
# import open_clip
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from library import model_util, sdxl_model_util
|
||||
import networks.lora as lora
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
|
||||
@@ -154,12 +158,13 @@ if __name__ == "__main__":
|
||||
text_model2.eval()
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
||||
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
||||
vae.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Tokenizers
|
||||
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
|
||||
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
|
||||
# tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
|
||||
tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name)
|
||||
|
||||
# LoRA
|
||||
for weights_file in args.lora_weights:
|
||||
@@ -192,7 +197,9 @@ if __name__ == "__main__":
|
||||
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
||||
# logger.info("emb1", emb1.shape)
|
||||
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
||||
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
||||
uc_vector = c_vector.clone().to(
|
||||
DEVICE, dtype=DTYPE
|
||||
) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
||||
|
||||
# crossattn
|
||||
|
||||
@@ -215,13 +222,22 @@ if __name__ == "__main__":
|
||||
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
|
||||
|
||||
# text encoder 2
|
||||
with torch.no_grad():
|
||||
tokens = tokenizer2(text2).to(DEVICE)
|
||||
# tokens = tokenizer2(text2).to(DEVICE)
|
||||
tokens = tokenizer2(
|
||||
text,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||
# logger.info("hidden_states2", text_embedding2_penu.shape)
|
||||
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
|
||||
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
|
||||
|
||||
# 連結して終了 concat and finish
|
||||
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
|
||||
|
||||
Reference in New Issue
Block a user