mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Disable Diffusers slicing if device is not XPU
This commit is contained in:
@@ -20,7 +20,7 @@ def find_slice_size(slice_size, slice_block_size):
|
||||
return slice_size
|
||||
|
||||
@cache
|
||||
def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None):
|
||||
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
||||
if len(query_shape) == 3:
|
||||
batch_size_attention, query_tokens, shape_three = query_shape
|
||||
shape_four = 1
|
||||
@@ -40,6 +40,9 @@ def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None)
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
|
||||
if query_device_type != "xpu"
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
if block_size > attention_slice_rate:
|
||||
do_split = True
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
@@ -107,7 +110,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
|
||||
####################################################################
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), slice_size=self.slice_size)
|
||||
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
||||
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
@@ -227,7 +230,7 @@ class AttnProcessor:
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size())
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
||||
|
||||
if do_split:
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
|
||||
Reference in New Issue
Block a user