mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix a bug x is updated inplace
This commit is contained in:
@@ -153,7 +153,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
# down reduces the number of input dimensions and combines it with conditioning image embedding
|
||||
# we expect that it will mix well by combining in the channel direction instead of adding
|
||||
|
||||
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
|
||||
cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
|
||||
cx = self.mid(cx)
|
||||
|
||||
if self.dropout is not None and self.training:
|
||||
@@ -161,18 +161,11 @@ class LLLiteModule(torch.nn.Module):
|
||||
|
||||
cx = self.up(cx)
|
||||
|
||||
# residualを加算する / add residual
|
||||
# residua (x) lを加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
||||
if self.batch_cond_only:
|
||||
x[1::2] += cx
|
||||
else:
|
||||
# to_outを対象とすると、cloneがないと次のエラーが出る / if to_out is the target, the following error will occur without clone
|
||||
# RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace.
|
||||
# This view was created inside a custom Function ...
|
||||
# x = x.clone()
|
||||
cx = torch.zeros_like(x)[1::2] + cx
|
||||
|
||||
x += cx
|
||||
|
||||
x = self.org_forward(x) # ここで元のモジュールを呼び出す / call the original module here
|
||||
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
||||
return x
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user