From 0646112010acaea0d2f3f3235472e05f013e9b18 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 20 Aug 2023 00:09:09 +0900 Subject: [PATCH] fix a bug x is updated inplace --- networks/control_net_lllite.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index e0157080..36e36071 100644 --- a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -153,7 +153,7 @@ class LLLiteModule(torch.nn.Module): # down reduces the number of input dimensions and combines it with conditioning image embedding # we expect that it will mix well by combining in the channel direction instead of adding - cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2) + cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2) cx = self.mid(cx) if self.dropout is not None and self.training: @@ -161,18 +161,11 @@ class LLLiteModule(torch.nn.Module): cx = self.up(cx) - # residualを加算する / add residual + # residua (x) lを加算して元のforwardを呼び出す / add residual (x) and call the original forward if self.batch_cond_only: - x[1::2] += cx - else: - # to_outを対象とすると、cloneがないと次のエラーが出る / if to_out is the target, the following error will occur without clone - # RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace. - # This view was created inside a custom Function ... - # x = x.clone() + cx = torch.zeros_like(x)[1::2] + cx - x += cx - - x = self.org_forward(x) # ここで元のモジュールを呼び出す / call the original module here + x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here return x