diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py
index f856d4e7..03d13039 100644
--- a/networks/lora_lumina.py
+++ b/networks/lora_lumina.py
@@ -22,10 +22,6 @@ import logging
 
 logger = logging.getLogger(__name__)
 
-NUM_DOUBLE_BLOCKS = 19
-NUM_SINGLE_BLOCKS = 38
-
-
 class LoRAModule(torch.nn.Module):
     """
     replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -142,9 +138,6 @@ class LoRAModule(torch.nn.Module):
 
             lx = self.lora_up(lx)
 
-            if self.dropout is not None and self.training:
-                lx = torch.nn.functional.dropout(lx, p=self.dropout)
-
             return org_forwarded + lx * self.multiplier * scale
         else:
             lxs = [lora_down(x) for lora_down in self.lora_down]
@@ -170,9 +163,6 @@ class LoRAModule(torch.nn.Module):
 
             lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
 
-            if self.dropout is not None and self.training:
-                lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
-
             return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
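
Note: after this patch, `LoRAModule.forward` no longer applies dropout to the LoRA branch (the `torch.nn.functional.dropout` calls are removed from both the single and split paths). For illustration only, a minimal sketch of what the single-path forward computes after the change; the class name, constructor arguments, and zero-init of `lora_up` are assumptions for the sketch, not taken from this diff — only `org_forward`, `lora_down`, `lora_up`, `multiplier`, and `scale` follow the code above:

```python
import torch


class LoRAForwardSketch(torch.nn.Module):
    """Sketch of the LoRA forward path after this patch (in-module dropout removed)."""

    def __init__(self, org_module: torch.nn.Linear, lora_dim: int = 4, alpha: float = 1.0, multiplier: float = 1.0):
        super().__init__()
        in_dim, out_dim = org_module.in_features, org_module.out_features
        self.org_forward = org_module.forward       # original Linear is kept; only its forward is wrapped
        self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
        self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
        torch.nn.init.zeros_(self.lora_up.weight)   # assumed zero-init so the LoRA branch starts as a no-op
        self.multiplier = multiplier
        self.scale = alpha / lora_dim               # assumed conventional LoRA scaling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        org_forwarded = self.org_forward(x)
        lx = self.lora_down(x)
        lx = self.lora_up(lx)
        # no dropout here: the patch drops the in-module dropout between lora_up and the residual add
        return org_forwarded + lx * self.multiplier * self.scale


if __name__ == "__main__":
    base = torch.nn.Linear(16, 32)
    lora = LoRAForwardSketch(base, lora_dim=4)
    print(lora(torch.randn(2, 16)).shape)  # torch.Size([2, 32])
```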