fix attention couple+deep shink cause error in some reso

This commit is contained in:
Kohya S
2024-04-03 12:43:08 +09:00
parent 2258a1b753
commit b748b48dbb

View File

@@ -247,14 +247,13 @@ class LoRAInfModule(LoRAModule):
area = x.size()[1] area = x.size()[1]
mask = self.network.mask_dic.get(area, None) mask = self.network.mask_dic.get(area, None)
if mask is None: if mask is None or len(x.size()) == 2:
# raise ValueError(f"mask is None for resolution {area}")
# emb_layers in SDXL doesn't have mask # emb_layers in SDXL doesn't have mask
# if "emb" not in self.lora_name: # if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4: if len(x.size()) == 3:
mask = torch.reshape(mask, (1, -1, 1)) mask = torch.reshape(mask, (1, -1, 1))
return mask return mask