Add OpenVINO and ROCm ONNX Runtime for WD14

This commit is contained in:
Disty0
2024-03-27 03:21:13 +03:00
parent 78e0a7630c
commit 6f7e93d5cc
3 changed files with 33 additions and 20 deletions

View File

@@ -142,10 +142,21 @@ def main(args):
del model del model
if "OpenVINOExecutionProvider" in ort.get_available_providers():
# requires provider options for gpu support
# fp16 causes nonsense outputs
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU_FP32"}],
)
else:
ort_sess = ort.InferenceSession( ort_sess = ort.InferenceSession(
onnx_path, onnx_path,
providers=( providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"] ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
), ),
) )
else: else:

View File

@@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
mat2[start_idx:end_idx], mat2[start_idx:end_idx],
out=out out=out
) )
torch.xpu.synchronize(input.device)
else: else:
return original_torch_bmm(input, mat2, out=out) return original_torch_bmm(input, mat2, out=out)
torch.xpu.synchronize(input.device)
return hidden_states return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.device.type != "xpu": if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
# Slice SDPA # Slice SDPA
@@ -153,7 +153,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal dropout_p=dropout_p, is_causal=is_causal, **kwargs
) )
else: else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
@@ -161,7 +161,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
key[start_idx:end_idx, start_idx_2:end_idx_2], key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2], value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal dropout_p=dropout_p, is_causal=is_causal, **kwargs
) )
else: else:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
@@ -169,9 +169,9 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
key[start_idx:end_idx], key[start_idx:end_idx],
value[start_idx:end_idx], value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal dropout_p=dropout_p, is_causal=is_causal, **kwargs
) )
else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
torch.xpu.synchronize(query.device) torch.xpu.synchronize(query.device)
else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
return hidden_states return hidden_states

View File

@@ -12,7 +12,7 @@ device_supports_fp64 = torch.xpu.has_fp64_dtype()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1: if isinstance(device_ids, list) and len(device_ids) > 1:
logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices") print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu") return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
@@ -42,7 +42,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
original_interpolate = torch.nn.functional.interpolate original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate) @wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None: if antialias or align_corners is not None or mode == 'bicubic':
return_device = tensor.device return_device = tensor.device
return_dtype = tensor.dtype return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -216,7 +216,9 @@ def torch_empty(*args, device=None, **kwargs):
original_torch_randn = torch.randn original_torch_randn = torch.randn
@wraps(torch.randn) @wraps(torch.randn)
def torch_randn(*args, device=None, **kwargs): def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype == bytes:
dtype = None
if check_device(device): if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs) return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else: else:
@@ -256,11 +258,11 @@ def torch_Generator(device=None):
original_torch_load = torch.load original_torch_load = torch.load
@wraps(torch.load) @wraps(torch.load)
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): def torch_load(f, map_location=None, *args, **kwargs):
if check_device(map_location): if check_device(map_location):
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs)
else: else:
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) return original_torch_load(f, map_location=map_location, *args, **kwargs)
# Hijack Functions: # Hijack Functions: