IPEX fix SDPA

This commit is contained in:
Disty0
2023-12-19 22:59:06 +03:00
parent aff05e043f
commit 8556b9d7f5
2 changed files with 64 additions and 58 deletions

View File

@@ -185,6 +185,10 @@ def ipex_hijacks():
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
else:
CondFunc('torch.Generator',
lambda orig_func, device=None: orig_func(return_xpu(device)),
lambda orig_func, device=None: check_device(device))
# TiledVAE and ControlNet:
CondFunc('torch.batch_norm',