diff --git a/library/original_unet.py b/library/original_unet.py index aa9dc233..375126cb 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -508,7 +508,7 @@ class DownBlock2D(nn.Module): self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)] + self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels=out_channels)]) else: self.downsamplers = None