This commit is contained in:
Disty0
2024-01-01 12:51:23 +03:00
parent 15d5e78ac2
commit 479bac447e

View File

@@ -122,12 +122,12 @@ def torch_cat(tensor, *args, **kwargs):
return original_torch_cat(tensor, *args, **kwargs)
# SwinIR BF16:
original_funtional_pad = torch.nn.functional.pad
def funtional_pad(input, pad, mode='constant', value=None):
original_functional_pad = torch.nn.functional.pad
def functional_pad(input, pad, mode='constant', value=None):
if mode == 'reflect' and input.dtype == torch.bfloat16:
return original_funtional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
else:
return original_funtional_pad(input, pad, mode=mode, value=value)
return original_functional_pad(input, pad, mode=mode, value=value)
original_torch_tensor = torch.tensor
@@ -240,7 +240,7 @@ def ipex_hijacks():
torch.nn.functional.linear = functional_linear
torch.nn.functional.conv2d = functional_conv2d
torch.nn.functional.interpolate = interpolate
torch.nn.functional.pad = funtional_pad
torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm
torch.cat = torch_cat