Fix bug in FLUX multi GPU training

This commit is contained in:
kohya-ss
2024-08-22 12:37:41 +09:00
parent e1cd19c0c0
commit 98c91a7625
8 changed files with 156 additions and 38 deletions

View File

@@ -745,7 +745,9 @@ class DoubleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask)
return torch.utils.checkpoint.checkpoint(
create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
)
else:
return self._forward(img, txt, vec, pe, txt_attention_mask)
@@ -836,7 +838,7 @@ class SingleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe)
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
else:
return self._forward(x, vec, pe)