diff --git a/networks/lora.py b/networks/lora.py
index cd73cbe7..0c75cd42 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -241,9 +241,13 @@ class LoRAInfModule(LoRAModule):
         else:
             area = x.size()[1]
 
-        mask = self.network.mask_dic[area]
+        mask = self.network.mask_dic.get(area, None)
         if mask is None:
-            raise ValueError(f"mask is None for resolution {area}")
+            # raise ValueError(f"mask is None for resolution {area}")
+            # emb_layers in SDXL doesn't have mask
+            # print(f"mask is None for resolution {area}, {x.size()}")
+            mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
+            return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
         if len(x.size()) != 4:
             mask = torch.reshape(mask, (1, -1, 1))
         return mask
@@ -348,9 +352,10 @@ class LoRAInfModule(LoRAModule):
         out[-self.network.batch_size :] = x[-self.network.batch_size :]  # real_uncond
 
         # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
-        # for i in range(len(masks)):
-        #     if masks[i] is None:
-        #         masks[i] = torch.zeros_like(masks[-1])
+        # if num_sub_prompts > num of LoRAs, fill with zero
+        for i in range(len(masks)):
+            if masks[i] is None:
+                masks[i] = torch.zeros_like(masks[0])
 
         mask = torch.cat(masks)
         mask_sum = torch.sum(mask, dim=0) + 1e-4
diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py
index 35bd6c61..c506ad3f 100755
--- a/sdxl_gen_img.py
+++ b/sdxl_gen_img.py
@@ -451,10 +451,11 @@ class PipelineLike:
         tes_text_embs = []
         tes_uncond_embs = []
         tes_real_uncond_embs = []
-        # use last pool
+
         for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
             token_replacer = self.get_token_replacer(tokenizer)
 
+            # use last text_pool, because it is from text encoder 2
             text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings(
                 tokenizer,
                 text_encoder,
@@ -529,6 +530,11 @@ class PipelineLike:
         c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
         uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
 
+        if reginonal_network:
+            # use last pool for conditioning
+            num_sub_prompts = len(text_pool) // batch_size
+            text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts]  # last subprompt
+
         c_vector = torch.cat([text_pool, c_vector], dim=1)
         uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
 
@@ -762,7 +768,7 @@ class PipelineLike:
             image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
         if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+            torch.cuda.empty_cache()
 
         if output_type == "pil":
             # image = self.numpy_to_pil(image)