mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add OpenVINO and ROCm ONNX Runtime for WD14
This commit is contained in:
@@ -142,10 +142,21 @@ def main(args):
|
|||||||
|
|
||||||
del model
|
del model
|
||||||
|
|
||||||
|
if "OpenVINOExecutionProvider" in ort.get_available_providers():
|
||||||
|
# requires provider options for gpu support
|
||||||
|
# fp16 causes nonsense outputs
|
||||||
|
ort_sess = ort.InferenceSession(
|
||||||
|
onnx_path,
|
||||||
|
providers=(["OpenVINOExecutionProvider"]),
|
||||||
|
provider_options=[{'device_type' : "GPU_FP32"}],
|
||||||
|
)
|
||||||
|
else:
|
||||||
ort_sess = ort.InferenceSession(
|
ort_sess = ort.InferenceSession(
|
||||||
onnx_path,
|
onnx_path,
|
||||||
providers=(
|
providers=(
|
||||||
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
|
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
|
||||||
|
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
|
||||||
|
["CPUExecutionProvider"]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
|
|||||||
mat2[start_idx:end_idx],
|
mat2[start_idx:end_idx],
|
||||||
out=out
|
out=out
|
||||||
)
|
)
|
||||||
|
torch.xpu.synchronize(input.device)
|
||||||
else:
|
else:
|
||||||
return original_torch_bmm(input, mat2, out=out)
|
return original_torch_bmm(input, mat2, out=out)
|
||||||
torch.xpu.synchronize(input.device)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
||||||
if query.device.type != "xpu":
|
if query.device.type != "xpu":
|
||||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
||||||
|
|
||||||
# Slice SDPA
|
# Slice SDPA
|
||||||
@@ -153,7 +153,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
|||||||
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||||
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
||||||
dropout_p=dropout_p, is_causal=is_causal
|
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||||
@@ -161,7 +161,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
|||||||
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||||
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||||
dropout_p=dropout_p, is_causal=is_causal
|
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||||
@@ -169,9 +169,9 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
|||||||
key[start_idx:end_idx],
|
key[start_idx:end_idx],
|
||||||
value[start_idx:end_idx],
|
value[start_idx:end_idx],
|
||||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||||
dropout_p=dropout_p, is_causal=is_causal
|
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
|
||||||
torch.xpu.synchronize(query.device)
|
torch.xpu.synchronize(query.device)
|
||||||
|
else:
|
||||||
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
|||||||
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
||||||
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
||||||
if isinstance(device_ids, list) and len(device_ids) > 1:
|
if isinstance(device_ids, list) and len(device_ids) > 1:
|
||||||
logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
||||||
return module.to("xpu")
|
return module.to("xpu")
|
||||||
|
|
||||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||||
@@ -42,7 +42,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
|
|||||||
original_interpolate = torch.nn.functional.interpolate
|
original_interpolate = torch.nn.functional.interpolate
|
||||||
@wraps(torch.nn.functional.interpolate)
|
@wraps(torch.nn.functional.interpolate)
|
||||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||||
if antialias or align_corners is not None:
|
if antialias or align_corners is not None or mode == 'bicubic':
|
||||||
return_device = tensor.device
|
return_device = tensor.device
|
||||||
return_dtype = tensor.dtype
|
return_dtype = tensor.dtype
|
||||||
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
||||||
@@ -216,7 +216,9 @@ def torch_empty(*args, device=None, **kwargs):
|
|||||||
|
|
||||||
original_torch_randn = torch.randn
|
original_torch_randn = torch.randn
|
||||||
@wraps(torch.randn)
|
@wraps(torch.randn)
|
||||||
def torch_randn(*args, device=None, **kwargs):
|
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
||||||
|
if dtype == bytes:
|
||||||
|
dtype = None
|
||||||
if check_device(device):
|
if check_device(device):
|
||||||
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||||
else:
|
else:
|
||||||
@@ -256,11 +258,11 @@ def torch_Generator(device=None):
|
|||||||
|
|
||||||
original_torch_load = torch.load
|
original_torch_load = torch.load
|
||||||
@wraps(torch.load)
|
@wraps(torch.load)
|
||||||
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
|
def torch_load(f, map_location=None, *args, **kwargs):
|
||||||
if check_device(map_location):
|
if check_device(map_location):
|
||||||
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
return original_torch_load(f, map_location=return_xpu(map_location), *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
return original_torch_load(f, map_location=map_location, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Hijack Functions:
|
# Hijack Functions:
|
||||||
|
|||||||
Reference in New Issue
Block a user