mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix to work regional LoRA
This commit is contained in:
@@ -241,9 +241,13 @@ class LoRAInfModule(LoRAModule):
|
||||
else:
|
||||
area = x.size()[1]
|
||||
|
||||
mask = self.network.mask_dic[area]
|
||||
mask = self.network.mask_dic.get(area, None)
|
||||
if mask is None:
|
||||
raise ValueError(f"mask is None for resolution {area}")
|
||||
# raise ValueError(f"mask is None for resolution {area}")
|
||||
# emb_layers in SDXL doesn't have mask
|
||||
# print(f"mask is None for resolution {area}, {x.size()}")
|
||||
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
|
||||
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
|
||||
if len(x.size()) != 4:
|
||||
mask = torch.reshape(mask, (1, -1, 1))
|
||||
return mask
|
||||
@@ -348,9 +352,10 @@ class LoRAInfModule(LoRAModule):
|
||||
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
||||
|
||||
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||
# for i in range(len(masks)):
|
||||
# if masks[i] is None:
|
||||
# masks[i] = torch.zeros_like(masks[-1])
|
||||
# if num_sub_prompts > num of LoRAs, fill with zero
|
||||
for i in range(len(masks)):
|
||||
if masks[i] is None:
|
||||
masks[i] = torch.zeros_like(masks[0])
|
||||
|
||||
mask = torch.cat(masks)
|
||||
mask_sum = torch.sum(mask, dim=0) + 1e-4
|
||||
|
||||
Reference in New Issue
Block a user