mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00

fix to work regional LoRA

@@ -241,9 +241,13 @@ class LoRAInfModule(LoRAModule):
         else:
             area = x.size()[1]
 
-        mask = self.network.mask_dic[area]
+        mask = self.network.mask_dic.get(area, None)
         if mask is None:
-            raise ValueError(f"mask is None for resolution {area}")
+            # raise ValueError(f"mask is None for resolution {area}")
+            # emb_layers in SDXL doesn't have mask
+            # print(f"mask is None for resolution {area}, {x.size()}")
+            mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
+            return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
         if len(x.size()) != 4:
             mask = torch.reshape(mask, (1, -1, 1))
         return mask
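
The lookup no longer hard-fails when a resolution has no registered mask: SDXL's emb_layers receive 2D inputs whose size never appears in mask_dic, so the module now falls back to a uniform mask weighting every sub-prompt equally. A minimal standalone sketch of that fallback path (the free-function wrapper and the 4D area computation are assumptions for illustration; mask_dic and num_sub_prompts mirror the network attributes in the diff):

import torch

# Hypothetical standalone version of the method above, for illustration only;
# the 4D branch computing area is assumed from surrounding context.
def get_mask_for_x(x, mask_dic, num_sub_prompts):
    if len(x.size()) == 4:
        area = x.size()[2] * x.size()[3]  # assumed: h * w for conv features
    else:
        area = x.size()[1]

    mask = mask_dic.get(area, None)
    if mask is None:
        # no mask registered (e.g. SDXL emb_layers): weight all sub-prompts equally
        mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
        return torch.ones(mask_size, dtype=x.dtype, device=x.device) / num_sub_prompts
    if len(x.size()) != 4:
        mask = torch.reshape(mask, (1, -1, 1))
    return mask

x = torch.randn(2, 1280)                      # 2D input, like an SDXL embedding layer
m = get_mask_for_x(x, {}, num_sub_prompts=2)  # empty dict: takes the fallback path
print(m.shape, m[0, 0].item())                # torch.Size([1, 1280]) 0.5
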
@@ -348,9 +352,10 @@ class LoRAInfModule(LoRAModule):
         out[-self.network.batch_size :] = x[-self.network.batch_size :]  # real_uncond
 
         # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
-        # for i in range(len(masks)):
-        #     if masks[i] is None:
-        #         masks[i] = torch.zeros_like(masks[-1])
+        # if num_sub_prompts > num of LoRAs, fill with zero
+        for i in range(len(masks)):
+            if masks[i] is None:
+                masks[i] = torch.zeros_like(masks[0])
 
         mask = torch.cat(masks)
         mask_sum = torch.sum(mask, dim=0) + 1e-4
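
The uncommented loop handles the case where there are more sub-prompts than region masks: missing slots are zero-filled so torch.cat succeeds, and the epsilon added to mask_sum keeps the subsequent division finite where no LoRA covers a position. A minimal sketch with toy shapes (the mask values are made up; like the diff, it assumes masks[0] is always populated):

import torch

# Toy masks for three sub-prompts over two positions; the middle LoRA is missing.
masks = [torch.tensor([[1.0, 0.0]]), None, torch.tensor([[0.0, 1.0]])]

# if num_sub_prompts > num of LoRAs, fill with zero (assumes masks[0] exists)
for i in range(len(masks)):
    if masks[i] is None:
        masks[i] = torch.zeros_like(masks[0])

mask = torch.cat(masks)                   # (num_sub_prompts, positions)
mask_sum = torch.sum(mask, dim=0) + 1e-4  # epsilon avoids division by zero
print(mask / mask_sum)                    # per-position normalized weights
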
@@ -451,10 +451,11 @@ class PipelineLike:
         tes_text_embs = []
         tes_uncond_embs = []
         tes_real_uncond_embs = []
-        # use last pool
+
         for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
             token_replacer = self.get_token_replacer(tokenizer)
 
+            # use last text_pool, because it is from text encoder 2
             text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings(
                 tokenizer,
                 text_encoder,
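
The relocated comment refers to SDXL's dual text encoders: get_weighted_text_embeddings runs once per encoder inside the loop, so after the final iteration text_pool holds the pooled output of text encoder 2, which is what the conditioning code below consumes.
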
@@ -529,6 +530,11 @@ class PipelineLike:
         c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
         uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
 
+        if reginonal_network:
+            # use last pool for conditioning
+            num_sub_prompts = len(text_pool) // batch_size
+            text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts]  # last subprompt
+
         c_vector = torch.cat([text_pool, c_vector], dim=1)
         uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
 
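
With a regional network active, text_pool stacks one pooled vector per sub-prompt per batch item, but the conditioning vector needs exactly one per batch item; the strided slice keeps the last sub-prompt's pool for each. A small sketch of that indexing with made-up dimensions:

import torch

batch_size, num_sub_prompts, dim = 2, 3, 4

# text_pool as produced with a regional network: batch_size * num_sub_prompts
# pooled vectors, grouped by batch item (values are just for illustration).
text_pool = torch.arange(batch_size * num_sub_prompts * dim, dtype=torch.float32).reshape(-1, dim)

n = len(text_pool) // batch_size   # recovers num_sub_prompts, as in the diff
last_pool = text_pool[n - 1 :: n]  # last sub-prompt's pool per batch item
print(last_pool.shape)             # torch.Size([2, 4])
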
@@ -762,7 +768,7 @@ class PipelineLike:
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
         if output_type == "pil":
             # image = self.numpy_to_pil(image)