mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
use CLIPTextModelWithProjection
This commit is contained in:
@@ -66,9 +66,9 @@ def get_timestep_embedding(x, outdim):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 画像生成条件を変更する場合はここを変更
|
||||
# 画像生成条件を変更する場合はここを変更 / change here to change image generation conditions
|
||||
|
||||
# SDXLの追加のvector embeddingへ渡す値
|
||||
# SDXLの追加のvector embeddingへ渡す値 / Values to pass to additional vector embedding of SDXL
|
||||
target_height = 1024
|
||||
target_width = 1024
|
||||
original_height = target_height
|
||||
@@ -95,6 +95,7 @@ if __name__ == "__main__":
|
||||
default=[],
|
||||
help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)",
|
||||
)
|
||||
parser.add_argument("--interactive", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# HuggingFaceのmodel id
|
||||
@@ -106,7 +107,7 @@ if __name__ == "__main__":
|
||||
|
||||
# 本体RAMが少ない場合はGPUにロードするといいかも
|
||||
# If the main RAM is small, it may be better to load it on the GPU
|
||||
text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
|
||||
)
|
||||
|
||||
@@ -140,9 +141,12 @@ if __name__ == "__main__":
|
||||
text_model2.to(DEVICE, dtype=DTYPE)
|
||||
text_model2.eval()
|
||||
|
||||
text_projection = text_projection.to(DEVICE, dtype=DTYPE)
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
vae.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Tokenizers
|
||||
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
|
||||
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
|
||||
|
||||
# LoRA
|
||||
for weights_file in args.lora_weights:
|
||||
@@ -157,19 +161,27 @@ if __name__ == "__main__":
|
||||
)
|
||||
lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
|
||||
|
||||
# prepare embedding
|
||||
with torch.no_grad():
|
||||
# vector
|
||||
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
||||
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
||||
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
||||
# print("emb1", emb1.shape)
|
||||
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
||||
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
||||
# scheduler
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
beta_end=SCHEDULER_LINEAR_END,
|
||||
beta_schedule=SCHEDLER_SCHEDULE,
|
||||
)
|
||||
|
||||
# crossattn
|
||||
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
|
||||
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
|
||||
def generate_image(text, negative_text, seed=None):
|
||||
# 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
|
||||
# prepare embedding
|
||||
with torch.no_grad():
|
||||
# vector
|
||||
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
||||
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
||||
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
||||
# print("emb1", emb1.shape)
|
||||
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
||||
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
||||
|
||||
# crossattn
|
||||
|
||||
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
|
||||
def call_text_encoder(text):
|
||||
@@ -184,30 +196,31 @@ if __name__ == "__main__":
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(DEVICE)
|
||||
|
||||
enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding1 = enc_out["hidden_states"][11]
|
||||
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
|
||||
with torch.no_grad():
|
||||
enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding1 = enc_out["hidden_states"][11]
|
||||
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
|
||||
|
||||
# text encoder 2
|
||||
tokens = tokenizer2(text).to(DEVICE)
|
||||
with torch.no_grad():
|
||||
tokens = tokenizer2(text).to(DEVICE)
|
||||
|
||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||
# print("hidden_states2", text_embedding2_penu.shape)
|
||||
text_embedding2_pool = enc_out["pooler_output"]
|
||||
text_embedding2_pool = text_embedding2_pool @ text_projection.to(text_embedding2_pool.dtype)
|
||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||
# print("hidden_states2", text_embedding2_penu.shape)
|
||||
text_embedding2_pool = enc_out["text_embeds"]
|
||||
|
||||
# 連結して終了 concat and finish
|
||||
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
|
||||
return text_embedding, text_embedding2_pool
|
||||
|
||||
# cond
|
||||
c_ctx, c_ctx_pool = call_text_encoder(args.prompt)
|
||||
c_ctx, c_ctx_pool = call_text_encoder(prompt)
|
||||
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
||||
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
||||
|
||||
# uncond
|
||||
uc_ctx, uc_ctx_pool = call_text_encoder(args.negative_prompt)
|
||||
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt)
|
||||
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
|
||||
|
||||
text_embeddings = torch.cat([uc_ctx, c_ctx])
|
||||
@@ -215,14 +228,6 @@ if __name__ == "__main__":
|
||||
|
||||
# メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する / to reduce memory usage, delete the Text Encoder here or move it to the CPU
|
||||
|
||||
# scheduler
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
beta_end=SCHEDULER_LINEAR_END,
|
||||
beta_schedule=SCHEDLER_SCHEDULE,
|
||||
)
|
||||
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
@@ -256,25 +261,26 @@ if __name__ == "__main__":
|
||||
# Copy from Diffusers
|
||||
timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE)
|
||||
num_latent_input = 2
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
with torch.no_grad():
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
||||
noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
||||
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
# latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
# latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
|
||||
# latents = 1 / 0.18215 * latents
|
||||
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
||||
latents = latents.to(torch.float32)
|
||||
image = vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# latents = 1 / 0.18215 * latents
|
||||
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
||||
latents = latents.to(torch.float32)
|
||||
image = vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
@@ -288,4 +294,20 @@ if __name__ == "__main__":
|
||||
for i, img in enumerate(image):
|
||||
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
|
||||
|
||||
print("Done!")
|
||||
if not args.interactive:
|
||||
generate_image(args.prompt, args.negative_prompt, args.seed)
|
||||
else:
|
||||
# loop for interactive
|
||||
while True:
|
||||
prompt = input("prompt: ")
|
||||
if prompt == "":
|
||||
break
|
||||
negative_prompt = input("negative prompt: ")
|
||||
seed = input("seed: ")
|
||||
if seed == "":
|
||||
seed = None
|
||||
else:
|
||||
seed = int(seed)
|
||||
generate_image(prompt, negative_prompt, seed)
|
||||
|
||||
print("Done!")
|
||||
|
||||
Reference in New Issue
Block a user