mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Fix bug in FLUX multi GPU training
This commit is contained in:
@@ -745,7 +745,9 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask)
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
|
||||
)
|
||||
|
||||
else:
|
||||
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
||||
@@ -836,7 +838,7 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe)
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
|
||||
else:
|
||||
return self._forward(x, vec, pe)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user