fix a bug x is updated inplace

This commit is contained in:
ykume
2023-08-20 00:09:09 +09:00
parent 5a86bbc0a0
commit 0646112010

View File

@@ -153,7 +153,7 @@ class LLLiteModule(torch.nn.Module):
# down reduces the number of input dimensions and combines it with conditioning image embedding
# we expect that it will mix well by combining in the channel direction instead of adding
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
cx = self.mid(cx)
if self.dropout is not None and self.training:
@@ -161,18 +161,11 @@ class LLLiteModule(torch.nn.Module):
cx = self.up(cx)
# residualを加算する / add residual
# residua (x) lを加算して元のforwardを呼び出す / add residual (x) and call the original forward
if self.batch_cond_only:
x[1::2] += cx
else:
# to_outを対象とすると、cloneがないと次のエラーが出る / if to_out is the target, the following error will occur without clone
# RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace.
# This view was created inside a custom Function ...
# x = x.clone()
cx = torch.zeros_like(x)[1::2] + cx
x += cx
x = self.org_forward(x) # ここで元のモジュールを呼び出す / call the original module here
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
return x