Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 14:34:23 +00:00)
Undo dropout after up
@@ -22,10 +22,6 @@ import logging
 
 logger = logging.getLogger(__name__)
 
 
 NUM_DOUBLE_BLOCKS = 19
 NUM_SINGLE_BLOCKS = 38
 
 
 class LoRAModule(torch.nn.Module):
     """
     replaces forward method of the original Linear, instead of replacing the original Linear module.
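
The docstring names the hook technique used by this module: the original torch.nn.Linear stays registered in the model, and only its forward method is swapped for the LoRA-aware one. A minimal sketch of that idea (the apply_lora helper is hypothetical, not this file's API):

import torch

def apply_lora(org_module: torch.nn.Linear, lora_module: torch.nn.Module) -> None:
    # keep a handle to the original forward, then take its place;
    # the Linear object itself stays in the parent model's module tree
    lora_module.org_forward = org_module.forward
    org_module.forward = lora_module.forward

This is why the forward bodies in the hunks below call self.org_forward(x) and add the LoRA delta on top of its output.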
@@ -142,9 +138,6 @@ class LoRAModule(torch.nn.Module):
 
             lx = self.lora_up(lx)
 
-            if self.dropout is not None and self.training:
-                lx = torch.nn.functional.dropout(lx, p=self.dropout)
-
             return org_forwarded + lx * self.multiplier * scale
         else:
             lxs = [lora_down(x) for lora_down in self.lora_down]
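
After this hunk, the non-split path no longer applies dropout a second time after lora_up; dropout is left on the low-rank activations between lora_down and lora_up, as in sd-scripts' other LoRA modules. A self-contained sketch of how the reverted forward plausibly looks (the LoRASketch class, its init, and the hyperparameter defaults are assumptions; only the forward structure follows the diff):

import torch

class LoRASketch(torch.nn.Module):
    def __init__(self, org_module: torch.nn.Linear, lora_dim: int = 4,
                 alpha: float = 1.0, multiplier: float = 1.0, dropout=None):
        super().__init__()
        self.lora_down = torch.nn.Linear(org_module.in_features, lora_dim, bias=False)
        self.lora_up = torch.nn.Linear(lora_dim, org_module.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_up.weight)  # the LoRA delta starts at zero
        self.scale = alpha / lora_dim
        self.multiplier = multiplier
        self.dropout = dropout
        self.org_forward = org_module.forward  # forward is captured, the module is not replaced

    def forward(self, x):
        org_forwarded = self.org_forward(x)
        lx = self.lora_down(x)
        # dropout on the low-rank activations, after lora_down (assumed placement);
        # the dropout after lora_up is what this commit removes
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)
        lx = self.lora_up(lx)
        return org_forwarded + lx * self.multiplier * self.scale

For example, LoRASketch(torch.nn.Linear(16, 16), dropout=0.1) behaves exactly like the wrapped Linear at initialization, because lora_up starts at zero and the added delta vanishes.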
@@ -170,9 +163,6 @@ class LoRAModule(torch.nn.Module):
 
             lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
 
-            if self.dropout is not None and self.training:
-                lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
-
             return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
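
The split_dims branch gets the same treatment: dropout per low-rank slice before the per-slice lora_up, with the slices concatenated on the last dimension. A rough sketch of the reverted method body, assuming self.lora_down and self.lora_up are matching lists of per-slice layers and self.scale stands in for the diff's local scale:

import torch

def forward_split(self, x):
    # sketch of the split_dims method body; attribute names follow the diff,
    # everything else is assumed
    org_forwarded = self.org_forward(x)
    lxs = [lora_down(x) for lora_down in self.lora_down]
    if self.dropout is not None and self.training:
        # dropout stays on the low-rank activations; the post-lora_up dropout is gone
        lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]
    lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
    # each down/up pair contributes one slice of the output width
    return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * self.scale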