mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix typo
This commit is contained in:
@@ -122,12 +122,12 @@ def torch_cat(tensor, *args, **kwargs):
|
||||
return original_torch_cat(tensor, *args, **kwargs)
|
||||
|
||||
# SwinIR BF16:
|
||||
original_funtional_pad = torch.nn.functional.pad
|
||||
def funtional_pad(input, pad, mode='constant', value=None):
|
||||
original_functional_pad = torch.nn.functional.pad
|
||||
def functional_pad(input, pad, mode='constant', value=None):
|
||||
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
||||
return original_funtional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
||||
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
||||
else:
|
||||
return original_funtional_pad(input, pad, mode=mode, value=value)
|
||||
return original_functional_pad(input, pad, mode=mode, value=value)
|
||||
|
||||
|
||||
original_torch_tensor = torch.tensor
|
||||
@@ -240,7 +240,7 @@ def ipex_hijacks():
|
||||
torch.nn.functional.linear = functional_linear
|
||||
torch.nn.functional.conv2d = functional_conv2d
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.nn.functional.pad = funtional_pad
|
||||
torch.nn.functional.pad = functional_pad
|
||||
|
||||
torch.bmm = torch_bmm
|
||||
torch.cat = torch_cat
|
||||
|
||||
Reference in New Issue
Block a user