diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 31b2bd0a..5c4e056d 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -62,7 +62,7 @@ def cat_h(sliced): return x -def resblock_forward(_self, num_slices, input_tensor, temb): +def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs): assert _self.upsample is None and _self.downsample is None assert _self.norm1.num_groups == _self.norm2.num_groups assert temb is None