diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 93fd7537..b6d246dd 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -122,12 +122,12 @@ def torch_cat(tensor, *args, **kwargs): return original_torch_cat(tensor, *args, **kwargs) # SwinIR BF16: -original_funtional_pad = torch.nn.functional.pad -def funtional_pad(input, pad, mode='constant', value=None): +original_functional_pad = torch.nn.functional.pad +def functional_pad(input, pad, mode='constant', value=None): if mode == 'reflect' and input.dtype == torch.bfloat16: - return original_funtional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) + return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) else: - return original_funtional_pad(input, pad, mode=mode, value=value) + return original_functional_pad(input, pad, mode=mode, value=value) original_torch_tensor = torch.tensor @@ -240,7 +240,7 @@ def ipex_hijacks(): torch.nn.functional.linear = functional_linear torch.nn.functional.conv2d = functional_conv2d torch.nn.functional.interpolate = interpolate - torch.nn.functional.pad = funtional_pad + torch.nn.functional.pad = functional_pad torch.bmm = torch_bmm torch.cat = torch_cat