mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix to work regional LoRA
This commit is contained in:
@@ -451,10 +451,11 @@ class PipelineLike:
|
||||
tes_text_embs = []
|
||||
tes_uncond_embs = []
|
||||
tes_real_uncond_embs = []
|
||||
# use last pool
|
||||
|
||||
for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
|
||||
token_replacer = self.get_token_replacer(tokenizer)
|
||||
|
||||
# use last text_pool, because it is from text encoder 2
|
||||
text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
@@ -529,6 +530,11 @@ class PipelineLike:
|
||||
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
|
||||
uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
|
||||
|
||||
if reginonal_network:
|
||||
# use last pool for conditioning
|
||||
num_sub_prompts = len(text_pool) // batch_size
|
||||
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
|
||||
|
||||
c_vector = torch.cat([text_pool, c_vector], dim=1)
|
||||
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
|
||||
|
||||
@@ -762,7 +768,7 @@ class PipelineLike:
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if output_type == "pil":
|
||||
# image = self.numpy_to_pil(image)
|
||||
|
||||
Reference in New Issue
Block a user