This commit is contained in:
Disty0
2024-01-01 12:51:23 +03:00
parent 15d5e78ac2
commit 479bac447e

View File

@@ -122,12 +122,12 @@ def torch_cat(tensor, *args, **kwargs):
return original_torch_cat(tensor, *args, **kwargs) return original_torch_cat(tensor, *args, **kwargs)
# SwinIR BF16: # SwinIR BF16:
original_funtional_pad = torch.nn.functional.pad original_functional_pad = torch.nn.functional.pad
def funtional_pad(input, pad, mode='constant', value=None): def functional_pad(input, pad, mode='constant', value=None):
if mode == 'reflect' and input.dtype == torch.bfloat16: if mode == 'reflect' and input.dtype == torch.bfloat16:
return original_funtional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
else: else:
return original_funtional_pad(input, pad, mode=mode, value=value) return original_functional_pad(input, pad, mode=mode, value=value)
original_torch_tensor = torch.tensor original_torch_tensor = torch.tensor
@@ -240,7 +240,7 @@ def ipex_hijacks():
torch.nn.functional.linear = functional_linear torch.nn.functional.linear = functional_linear
torch.nn.functional.conv2d = functional_conv2d torch.nn.functional.conv2d = functional_conv2d
torch.nn.functional.interpolate = interpolate torch.nn.functional.interpolate = interpolate
torch.nn.functional.pad = funtional_pad torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm torch.bmm = torch_bmm
torch.cat = torch_cat torch.cat = torch_cat