mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Disable Diffusers slicing if device is not XPU
This commit is contained in:
@@ -20,7 +20,7 @@ def find_slice_size(slice_size, slice_block_size):
|
|||||||
return slice_size
|
return slice_size
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None):
|
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
||||||
if len(query_shape) == 3:
|
if len(query_shape) == 3:
|
||||||
batch_size_attention, query_tokens, shape_three = query_shape
|
batch_size_attention, query_tokens, shape_three = query_shape
|
||||||
shape_four = 1
|
shape_four = 1
|
||||||
@@ -40,6 +40,9 @@ def find_attention_slice_sizes(query_shape, query_element_size, slice_size=None)
|
|||||||
do_split_2 = False
|
do_split_2 = False
|
||||||
do_split_3 = False
|
do_split_3 = False
|
||||||
|
|
||||||
|
if query_device_type != "xpu"
|
||||||
|
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||||
|
|
||||||
if block_size > attention_slice_rate:
|
if block_size > attention_slice_rate:
|
||||||
do_split = True
|
do_split = True
|
||||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||||
@@ -107,7 +110,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
|||||||
|
|
||||||
####################################################################
|
####################################################################
|
||||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||||
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), slice_size=self.slice_size)
|
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
||||||
|
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
start_idx = i * split_slice_size
|
start_idx = i * split_slice_size
|
||||||
@@ -227,7 +230,7 @@ class AttnProcessor:
|
|||||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size())
|
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
||||||
|
|
||||||
if do_split:
|
if do_split:
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
|
|||||||
Reference in New Issue
Block a user