mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix a bug x is updated inplace
This commit is contained in:
@@ -153,7 +153,7 @@ class LLLiteModule(torch.nn.Module):
|
|||||||
# down reduces the number of input dimensions and combines it with conditioning image embedding
|
# down reduces the number of input dimensions and combines it with conditioning image embedding
|
||||||
# we expect that it will mix well by combining in the channel direction instead of adding
|
# we expect that it will mix well by combining in the channel direction instead of adding
|
||||||
|
|
||||||
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
|
cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
|
||||||
cx = self.mid(cx)
|
cx = self.mid(cx)
|
||||||
|
|
||||||
if self.dropout is not None and self.training:
|
if self.dropout is not None and self.training:
|
||||||
@@ -161,18 +161,11 @@ class LLLiteModule(torch.nn.Module):
|
|||||||
|
|
||||||
cx = self.up(cx)
|
cx = self.up(cx)
|
||||||
|
|
||||||
# residualを加算する / add residual
|
# residua (x) lを加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
||||||
if self.batch_cond_only:
|
if self.batch_cond_only:
|
||||||
x[1::2] += cx
|
cx = torch.zeros_like(x)[1::2] + cx
|
||||||
else:
|
|
||||||
# to_outを対象とすると、cloneがないと次のエラーが出る / if to_out is the target, the following error will occur without clone
|
|
||||||
# RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace.
|
|
||||||
# This view was created inside a custom Function ...
|
|
||||||
# x = x.clone()
|
|
||||||
|
|
||||||
x += cx
|
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
||||||
|
|
||||||
x = self.org_forward(x) # ここで元のモジュールを呼び出す / call the original module here
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user