mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix typo
This commit is contained in:
@@ -122,12 +122,12 @@ def torch_cat(tensor, *args, **kwargs):
|
|||||||
return original_torch_cat(tensor, *args, **kwargs)
|
return original_torch_cat(tensor, *args, **kwargs)
|
||||||
|
|
||||||
# SwinIR BF16:
|
# SwinIR BF16:
|
||||||
original_funtional_pad = torch.nn.functional.pad
|
original_functional_pad = torch.nn.functional.pad
|
||||||
def funtional_pad(input, pad, mode='constant', value=None):
|
def functional_pad(input, pad, mode='constant', value=None):
|
||||||
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
||||||
return original_funtional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
||||||
else:
|
else:
|
||||||
return original_funtional_pad(input, pad, mode=mode, value=value)
|
return original_functional_pad(input, pad, mode=mode, value=value)
|
||||||
|
|
||||||
|
|
||||||
original_torch_tensor = torch.tensor
|
original_torch_tensor = torch.tensor
|
||||||
@@ -240,7 +240,7 @@ def ipex_hijacks():
|
|||||||
torch.nn.functional.linear = functional_linear
|
torch.nn.functional.linear = functional_linear
|
||||||
torch.nn.functional.conv2d = functional_conv2d
|
torch.nn.functional.conv2d = functional_conv2d
|
||||||
torch.nn.functional.interpolate = interpolate
|
torch.nn.functional.interpolate = interpolate
|
||||||
torch.nn.functional.pad = funtional_pad
|
torch.nn.functional.pad = functional_pad
|
||||||
|
|
||||||
torch.bmm = torch_bmm
|
torch.bmm = torch_bmm
|
||||||
torch.cat = torch_cat
|
torch.cat = torch_cat
|
||||||
|
|||||||
Reference in New Issue
Block a user