mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'dev' into gradual_latent_hires_fix
This commit is contained in:
@@ -53,6 +53,7 @@ class BaseSubsetParams:
|
||||
shuffle_caption: bool = False
|
||||
caption_separator: str = ',',
|
||||
keep_tokens: int = 0
|
||||
keep_tokens_separator: str = None,
|
||||
color_aug: bool = False
|
||||
flip_aug: bool = False
|
||||
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
||||
@@ -160,6 +161,7 @@ class ConfigSanitizer:
|
||||
"random_crop": bool,
|
||||
"shuffle_caption": bool,
|
||||
"keep_tokens": int,
|
||||
"keep_tokens_separator": str,
|
||||
"token_warmup_min": int,
|
||||
"token_warmup_step": Any(float,int),
|
||||
"caption_prefix": str,
|
||||
@@ -461,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
num_repeats: {subset.num_repeats}
|
||||
shuffle_caption: {subset.shuffle_caption}
|
||||
keep_tokens: {subset.keep_tokens}
|
||||
keep_tokens_separator: {subset.keep_tokens_separator}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||
|
||||
@@ -140,6 +140,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
|
||||
# C
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
|
||||
ipex._C._DeviceProperties.major = 2023
|
||||
ipex._C._DeviceProperties.minor = 2
|
||||
|
||||
@@ -156,20 +157,9 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.get_device_properties.minor = 7
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
if hasattr(torch.xpu, 'getDeviceIdListForCard'):
|
||||
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
|
||||
torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
|
||||
else:
|
||||
torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
|
||||
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
|
||||
|
||||
ipex_hijacks()
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
try:
|
||||
from .attention import attention_init
|
||||
attention_init()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
|
||||
@@ -1,45 +1,98 @@
|
||||
import os
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
from functools import cache
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
original_torch_bmm = torch.bmm
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
|
||||
|
||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
||||
block_multiply = input.element_size()
|
||||
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
|
||||
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||
|
||||
# Find something divisible with the input_tokens
|
||||
@cache
|
||||
def find_slice_size(slice_size, slice_block_size):
|
||||
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||
slice_size = slice_size // 2
|
||||
if slice_size <= 1:
|
||||
slice_size = 1
|
||||
break
|
||||
return slice_size
|
||||
|
||||
# Find slice sizes for SDPA
|
||||
@cache
|
||||
def find_sdpa_slice_sizes(query_shape, query_element_size):
|
||||
if len(query_shape) == 3:
|
||||
batch_size_attention, query_tokens, shape_three = query_shape
|
||||
shape_four = 1
|
||||
else:
|
||||
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||
|
||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
if block_size > 4:
|
||||
split_2_slice_size = query_tokens
|
||||
split_3_slice_size = shape_three
|
||||
|
||||
do_split = False
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
|
||||
if block_size > sdpa_slice_trigger_rate:
|
||||
do_split = True
|
||||
#Find something divisible with the input_tokens
|
||||
while (split_slice_size * slice_block_size) > 4:
|
||||
split_slice_size = split_slice_size // 2
|
||||
if split_slice_size <= 1:
|
||||
split_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split = False
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_2 = True
|
||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_3 = True
|
||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
# Find slice sizes for BMM
|
||||
@cache
|
||||
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
|
||||
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
|
||||
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
split_2_slice_size = input_tokens
|
||||
if split_slice_size * slice_block_size > 4:
|
||||
slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
|
||||
do_split_2 = True
|
||||
#Find something divisible with the input_tokens
|
||||
while (split_2_slice_size * slice_block_size2) > 4:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_2 = False
|
||||
split_3_slice_size = mat2_atten_shape
|
||||
|
||||
do_split = False
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
|
||||
if block_size > attention_slice_rate:
|
||||
do_split = True
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||
do_split_2 = True
|
||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
|
||||
do_split_3 = True
|
||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
|
||||
original_torch_bmm = torch.bmm
|
||||
def torch_bmm_32_bit(input, mat2, *, out=None):
|
||||
if input.device.type != "xpu":
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
|
||||
|
||||
# Slice BMM
|
||||
if do_split:
|
||||
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
||||
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
@@ -48,11 +101,21 @@ def torch_bmm(input, mat2, *, out=None):
|
||||
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
||||
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
out=out
|
||||
)
|
||||
if do_split_3:
|
||||
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
|
||||
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
out=out
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
||||
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
out=out
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
||||
input[start_idx:end_idx],
|
||||
@@ -64,51 +127,14 @@ def torch_bmm(input, mat2, *, out=None):
|
||||
return hidden_states
|
||||
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
if len(query.shape) == 3:
|
||||
batch_size_attention, query_tokens, shape_four = query.shape
|
||||
shape_one = 1
|
||||
no_shape_one = True
|
||||
else:
|
||||
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
|
||||
no_shape_one = False
|
||||
|
||||
if query.dtype != key.dtype:
|
||||
key = key.to(dtype=query.dtype)
|
||||
if query.dtype != value.dtype:
|
||||
value = value.to(dtype=query.dtype)
|
||||
|
||||
block_multiply = query.element_size()
|
||||
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
if block_size > 4:
|
||||
do_split = True
|
||||
#Find something divisible with the shape_one
|
||||
while (split_slice_size * slice_block_size) > 4:
|
||||
split_slice_size = split_slice_size // 2
|
||||
if split_slice_size <= 1:
|
||||
split_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split = False
|
||||
|
||||
split_2_slice_size = query_tokens
|
||||
if split_slice_size * slice_block_size > 4:
|
||||
slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
|
||||
do_split_2 = True
|
||||
#Find something divisible with the batch_size_attention
|
||||
while (split_2_slice_size * slice_block_size2) > 4:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_2 = False
|
||||
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
if query.device.type != "xpu":
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
||||
|
||||
# Slice SDPA
|
||||
if do_split:
|
||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
@@ -117,7 +143,18 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if no_shape_one:
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
@@ -125,38 +162,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
if no_shape_one:
|
||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
query[:, start_idx:end_idx],
|
||||
key[:, start_idx:end_idx],
|
||||
value[:, start_idx:end_idx],
|
||||
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
return original_scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
return hidden_states
|
||||
|
||||
def attention_init():
|
||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
||||
torch.bmm = torch_bmm
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
|
||||
@@ -1,10 +1,62 @@
|
||||
import os
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import diffusers #0.24.0 # pylint: disable=import-error
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.utils import USE_PEFT_BACKEND
|
||||
from functools import cache
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||
|
||||
@cache
|
||||
def find_slice_size(slice_size, slice_block_size):
|
||||
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||
slice_size = slice_size // 2
|
||||
if slice_size <= 1:
|
||||
slice_size = 1
|
||||
break
|
||||
return slice_size
|
||||
|
||||
@cache
|
||||
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
||||
if len(query_shape) == 3:
|
||||
batch_size_attention, query_tokens, shape_three = query_shape
|
||||
shape_four = 1
|
||||
else:
|
||||
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||
if slice_size is not None:
|
||||
batch_size_attention = slice_size
|
||||
|
||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
split_2_slice_size = query_tokens
|
||||
split_3_slice_size = shape_three
|
||||
|
||||
do_split = False
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
|
||||
if query_device_type != "xpu":
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
if block_size > attention_slice_rate:
|
||||
do_split = True
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_2 = True
|
||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_3 = True
|
||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
r"""
|
||||
Processor for implementing sliced attention.
|
||||
@@ -18,7 +70,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
def __init__(self, slice_size):
|
||||
self.slice_size = slice_size
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
@@ -54,49 +108,61 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
|
||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
block_multiply = query.element_size()
|
||||
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
|
||||
block_size = query_tokens * slice_block_size
|
||||
split_2_slice_size = query_tokens
|
||||
if block_size > 4:
|
||||
do_split_2 = True
|
||||
#Find something divisible with the query_tokens
|
||||
while (split_2_slice_size * slice_block_size) > 4:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_2 = False
|
||||
|
||||
for i in range(batch_size_attention // self.slice_size):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
####################################################################
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
||||
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
del attn_slice
|
||||
####################################################################
|
||||
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
@@ -115,6 +181,130 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
r"""
|
||||
Default processor for performing attention-related computations.
|
||||
"""
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states=None, attention_mask=None,
|
||||
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
####################################################################
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
||||
|
||||
if do_split:
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
####################################################################
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def ipex_diffusers():
|
||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
||||
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
||||
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
|
||||
|
||||
@@ -1,67 +1,9 @@
|
||||
import contextlib
|
||||
import importlib
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||
|
||||
class CondFunc: # pylint: disable=missing-class-docstring
|
||||
def __new__(cls, orig_func, sub_func, cond_func):
|
||||
self = super(CondFunc, cls).__new__(cls)
|
||||
if isinstance(orig_func, str):
|
||||
func_path = orig_func.split('.')
|
||||
for i in range(len(func_path)-1, -1, -1):
|
||||
try:
|
||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
||||
break
|
||||
except ImportError:
|
||||
pass
|
||||
for attr_name in func_path[i:-1]:
|
||||
resolved_obj = getattr(resolved_obj, attr_name)
|
||||
orig_func = getattr(resolved_obj, func_path[-1])
|
||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
||||
self.__init__(orig_func, sub_func, cond_func)
|
||||
return lambda *args, **kwargs: self(*args, **kwargs)
|
||||
def __init__(self, orig_func, sub_func, cond_func):
|
||||
self.__orig_func = orig_func
|
||||
self.__sub_func = sub_func
|
||||
self.__cond_func = cond_func
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
|
||||
return self.__sub_func(self.__orig_func, *args, **kwargs)
|
||||
else:
|
||||
return self.__orig_func(*args, **kwargs)
|
||||
|
||||
_utils = torch.utils.data._utils
|
||||
def _shutdown_workers(self):
|
||||
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
|
||||
return
|
||||
if hasattr(self, "_shutdown") and not self._shutdown:
|
||||
self._shutdown = True
|
||||
try:
|
||||
if hasattr(self, '_pin_memory_thread'):
|
||||
self._pin_memory_thread_done_event.set()
|
||||
self._worker_result_queue.put((None, None))
|
||||
self._pin_memory_thread.join()
|
||||
self._worker_result_queue.cancel_join_thread()
|
||||
self._worker_result_queue.close()
|
||||
self._workers_done_event.set()
|
||||
for worker_id in range(len(self._workers)):
|
||||
if self._persistent_workers or self._workers_status[worker_id]:
|
||||
self._mark_worker_as_unavailable(worker_id, shutdown=True)
|
||||
for w in self._workers: # pylint: disable=invalid-name
|
||||
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
|
||||
for q in self._index_queues: # pylint: disable=invalid-name
|
||||
q.cancel_join_thread()
|
||||
q.close()
|
||||
finally:
|
||||
if self._worker_pids_set:
|
||||
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
|
||||
self._worker_pids_set = False
|
||||
for w in self._workers: # pylint: disable=invalid-name
|
||||
if w.is_alive():
|
||||
w.terminate()
|
||||
|
||||
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
||||
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
||||
if isinstance(device_ids, list) and len(device_ids) > 1:
|
||||
@@ -71,17 +13,18 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
|
||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||
return contextlib.nullcontext()
|
||||
|
||||
@property
|
||||
def is_cuda(self):
|
||||
return self.device.type == 'xpu' or self.device.type == 'cuda'
|
||||
|
||||
def check_device(device):
|
||||
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
||||
|
||||
def return_xpu(device):
|
||||
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
||||
|
||||
def ipex_no_cuda(orig_func, *args, **kwargs):
|
||||
torch.cuda.is_available = lambda: False
|
||||
orig_func(*args, **kwargs)
|
||||
torch.cuda.is_available = torch.xpu.is_available
|
||||
|
||||
# Autocast
|
||||
original_autocast = torch.autocast
|
||||
def ipex_autocast(*args, **kwargs):
|
||||
if len(args) > 0 and args[0] == "cuda":
|
||||
@@ -89,15 +32,7 @@ def ipex_autocast(*args, **kwargs):
|
||||
else:
|
||||
return original_autocast(*args, **kwargs)
|
||||
|
||||
# Embedding BF16
|
||||
original_torch_cat = torch.cat
|
||||
def torch_cat(tensor, *args, **kwargs):
|
||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
||||
else:
|
||||
return original_torch_cat(tensor, *args, **kwargs)
|
||||
|
||||
# Latent antialias:
|
||||
# Latent Antialias CPU Offload:
|
||||
original_interpolate = torch.nn.functional.interpolate
|
||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||
if antialias or align_corners is not None:
|
||||
@@ -109,110 +44,205 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
|
||||
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
||||
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
||||
|
||||
original_linalg_solve = torch.linalg.solve
|
||||
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
|
||||
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
|
||||
return_device = A.device
|
||||
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
|
||||
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
||||
original_from_numpy = torch.from_numpy
|
||||
def from_numpy(ndarray):
|
||||
if ndarray.dtype == float:
|
||||
return original_from_numpy(ndarray.astype('float32'))
|
||||
else:
|
||||
return original_linalg_solve(A, B, *args, **kwargs)
|
||||
return original_from_numpy(ndarray)
|
||||
|
||||
@property
|
||||
def is_cuda(self):
|
||||
return self.device.type == 'xpu'
|
||||
if torch.xpu.has_fp64_dtype():
|
||||
original_torch_bmm = torch.bmm
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
else:
|
||||
# 32 bit attention workarounds for Alchemist:
|
||||
try:
|
||||
from .attention import torch_bmm_32_bit as original_torch_bmm
|
||||
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
original_torch_bmm = torch.bmm
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
|
||||
|
||||
# Data Type Errors:
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
if query.dtype != key.dtype:
|
||||
key = key.to(dtype=query.dtype)
|
||||
if query.dtype != value.dtype:
|
||||
value = value.to(dtype=query.dtype)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
|
||||
# A1111 FP16
|
||||
original_functional_group_norm = torch.nn.functional.group_norm
|
||||
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
||||
if weight is not None and input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
|
||||
|
||||
# A1111 BF16
|
||||
original_functional_layer_norm = torch.nn.functional.layer_norm
|
||||
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if weight is not None and input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
|
||||
|
||||
# Training
|
||||
original_functional_linear = torch.nn.functional.linear
|
||||
def functional_linear(input, weight, bias=None):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_linear(input, weight, bias=bias)
|
||||
|
||||
original_functional_conv2d = torch.nn.functional.conv2d
|
||||
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
|
||||
# A1111 Embedding BF16
|
||||
original_torch_cat = torch.cat
|
||||
def torch_cat(tensor, *args, **kwargs):
|
||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
||||
else:
|
||||
return original_torch_cat(tensor, *args, **kwargs)
|
||||
|
||||
# SwinIR BF16:
|
||||
original_functional_pad = torch.nn.functional.pad
|
||||
def functional_pad(input, pad, mode='constant', value=None):
|
||||
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
||||
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
||||
else:
|
||||
return original_functional_pad(input, pad, mode=mode, value=value)
|
||||
|
||||
|
||||
original_torch_tensor = torch.tensor
|
||||
def torch_tensor(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_tensor(*args, device=device, **kwargs)
|
||||
|
||||
original_Tensor_to = torch.Tensor.to
|
||||
def Tensor_to(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_Tensor_to(self, device, *args, **kwargs)
|
||||
|
||||
original_Tensor_cuda = torch.Tensor.cuda
|
||||
def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_Tensor_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
||||
def UntypedStorage_init(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
||||
|
||||
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
||||
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_torch_empty = torch.empty
|
||||
def torch_empty(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_empty(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_randn = torch.randn
|
||||
def torch_randn(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_randn(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_ones = torch.ones
|
||||
def torch_ones(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_ones(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_zeros = torch.zeros
|
||||
def torch_zeros(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_zeros(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_linspace = torch.linspace
|
||||
def torch_linspace(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_linspace(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_Generator = torch.Generator
|
||||
def torch_Generator(device=None):
|
||||
if check_device(device):
|
||||
return original_torch_Generator(return_xpu(device))
|
||||
else:
|
||||
return original_torch_Generator(device)
|
||||
|
||||
original_torch_load = torch.load
|
||||
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
|
||||
if check_device(map_location):
|
||||
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
else:
|
||||
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
|
||||
# Hijack Functions:
|
||||
def ipex_hijacks():
|
||||
CondFunc('torch.tensor',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.Tensor.to',
|
||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
||||
CondFunc('torch.Tensor.cuda',
|
||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
||||
CondFunc('torch.UntypedStorage.__init__',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.UntypedStorage.cuda',
|
||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
||||
CondFunc('torch.empty',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.randn',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.ones',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.zeros',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.linspace',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.load',
|
||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
|
||||
orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
|
||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
|
||||
torch.tensor = torch_tensor
|
||||
torch.Tensor.to = Tensor_to
|
||||
torch.Tensor.cuda = Tensor_cuda
|
||||
torch.UntypedStorage.__init__ = UntypedStorage_init
|
||||
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
||||
torch.empty = torch_empty
|
||||
torch.randn = torch_randn
|
||||
torch.ones = torch_ones
|
||||
torch.zeros = torch_zeros
|
||||
torch.linspace = torch_linspace
|
||||
torch.Generator = torch_Generator
|
||||
torch.load = torch_load
|
||||
|
||||
CondFunc('torch.Generator',
|
||||
lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
|
||||
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
|
||||
|
||||
# TiledVAE and ControlNet:
|
||||
CondFunc('torch.batch_norm',
|
||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
||||
CondFunc('torch.instance_norm',
|
||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
||||
|
||||
# Functions with dtype errors:
|
||||
CondFunc('torch.nn.modules.GroupNorm.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
# Training:
|
||||
CondFunc('torch.nn.modules.linear.Linear.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
# BF16:
|
||||
CondFunc('torch.nn.functional.layer_norm',
|
||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
|
||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
weight is not None and input.dtype != weight.data.dtype)
|
||||
# SwinIR BF16:
|
||||
CondFunc('torch.nn.functional.pad',
|
||||
lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
|
||||
lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
|
||||
|
||||
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
CondFunc('torch.from_numpy',
|
||||
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
|
||||
lambda orig_func, ndarray: ndarray.dtype == float)
|
||||
|
||||
# Broken functions when torch.cuda.is_available is True:
|
||||
# Pin Memory:
|
||||
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
|
||||
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
|
||||
lambda orig_func, *args, **kwargs: True)
|
||||
|
||||
# Functions that make compile mad with CondFunc:
|
||||
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
|
||||
torch.nn.DataParallel = DummyDataParallel
|
||||
torch.autocast = ipex_autocast
|
||||
torch.cat = torch_cat
|
||||
torch.linalg.solve = linalg_solve
|
||||
torch.UntypedStorage.is_cuda = is_cuda
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.backends.cuda.sdp_kernel = return_null_context
|
||||
torch.nn.DataParallel = DummyDataParallel
|
||||
torch.UntypedStorage.is_cuda = is_cuda
|
||||
torch.autocast = ipex_autocast
|
||||
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
torch.nn.functional.group_norm = functional_group_norm
|
||||
torch.nn.functional.layer_norm = functional_layer_norm
|
||||
torch.nn.functional.linear = functional_linear
|
||||
torch.nn.functional.conv2d = functional_conv2d
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.nn.functional.pad = functional_pad
|
||||
|
||||
torch.bmm = torch_bmm
|
||||
torch.cat = torch_cat
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
torch.from_numpy = from_numpy
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
import diffusers
|
||||
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
||||
@@ -520,6 +520,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
clip_skip: int = 1,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -531,32 +532,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=requires_safety_checker,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.clip_skip = clip_skip
|
||||
self.custom_clip_skip = clip_skip
|
||||
self.__init__additional__()
|
||||
|
||||
# else:
|
||||
# def __init__(
|
||||
# self,
|
||||
# vae: AutoencoderKL,
|
||||
# text_encoder: CLIPTextModel,
|
||||
# tokenizer: CLIPTokenizer,
|
||||
# unet: UNet2DConditionModel,
|
||||
# scheduler: SchedulerMixin,
|
||||
# safety_checker: StableDiffusionSafetyChecker,
|
||||
# feature_extractor: CLIPFeatureExtractor,
|
||||
# ):
|
||||
# super().__init__(
|
||||
# vae=vae,
|
||||
# text_encoder=text_encoder,
|
||||
# tokenizer=tokenizer,
|
||||
# unet=unet,
|
||||
# scheduler=scheduler,
|
||||
# safety_checker=safety_checker,
|
||||
# feature_extractor=feature_extractor,
|
||||
# )
|
||||
# self.__init__additional__()
|
||||
|
||||
def __init__additional__(self):
|
||||
if not hasattr(self, "vae_scale_factor"):
|
||||
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
||||
@@ -624,7 +604,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
prompt=prompt,
|
||||
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
clip_skip=self.clip_skip,
|
||||
clip_skip=self.custom_clip_skip,
|
||||
)
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
|
||||
@@ -4,10 +4,13 @@
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -571,9 +574,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
# support checkpoint without position_ids (invalid checkpoint)
|
||||
if "text_model.embeddings.position_ids" not in text_model_dict:
|
||||
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
||||
# remove position_ids for newer transformer, which causes error :(
|
||||
if "text_model.embeddings.position_ids" in text_model_dict:
|
||||
text_model_dict.pop("text_model.embeddings.position_ids")
|
||||
|
||||
return text_model_dict
|
||||
|
||||
@@ -1307,19 +1310,19 @@ def load_vae(vae_id, dtype):
|
||||
|
||||
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
||||
max_width, max_height = max_reso
|
||||
max_area = (max_width // divisible) * (max_height // divisible)
|
||||
max_area = max_width * max_height
|
||||
|
||||
resos = set()
|
||||
|
||||
size = int(math.sqrt(max_area)) * divisible
|
||||
resos.add((size, size))
|
||||
width = int(math.sqrt(max_area) // divisible) * divisible
|
||||
resos.add((width, width))
|
||||
|
||||
size = min_size
|
||||
while size <= max_size:
|
||||
width = size
|
||||
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
||||
resos.add((width, height))
|
||||
resos.add((height, width))
|
||||
width = min_size
|
||||
while width <= max_size:
|
||||
height = min(max_size, int((max_area // width) // divisible) * divisible)
|
||||
if height >= min_size:
|
||||
resos.add((width, height))
|
||||
resos.add((height, width))
|
||||
|
||||
# # make additional resos
|
||||
# if width >= height and width - divisible >= min_size:
|
||||
@@ -1329,7 +1332,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
|
||||
# resos.add((width, height - divisible))
|
||||
# resos.add((height - divisible, width))
|
||||
|
||||
size += divisible
|
||||
width += divisible
|
||||
|
||||
resos = list(resos)
|
||||
resos.sort()
|
||||
|
||||
@@ -100,7 +100,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
key = key.replace(".ln_final", ".final_layer_norm")
|
||||
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
|
||||
elif ".embeddings.position_ids" in key:
|
||||
key = None # remove this key: make position_ids by ourselves
|
||||
key = None # remove this key: position_ids is not used in newer transformers
|
||||
return key
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
@@ -126,10 +126,6 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
||||
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
||||
|
||||
# original SD にはないので、position_idsを追加
|
||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||
|
||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||
|
||||
@@ -265,9 +261,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
||||
elif k.startswith("conditioner.embedders.1.model."):
|
||||
te2_sd[k] = state_dict.pop(k)
|
||||
|
||||
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
|
||||
if "text_model.embeddings.position_ids" not in te1_sd:
|
||||
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
|
||||
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
|
||||
if "text_model.embeddings.position_ids" in te1_sd:
|
||||
te1_sd.pop("text_model.embeddings.position_ids")
|
||||
|
||||
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
||||
print("text encoder 1:", info1)
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import (
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
||||
import gc
|
||||
import glob
|
||||
import math
|
||||
@@ -351,6 +351,7 @@ class BaseSubset:
|
||||
shuffle_caption: bool,
|
||||
caption_separator: str,
|
||||
keep_tokens: int,
|
||||
keep_tokens_separator: str,
|
||||
color_aug: bool,
|
||||
flip_aug: bool,
|
||||
face_crop_aug_range: Optional[Tuple[float, float]],
|
||||
@@ -368,6 +369,7 @@ class BaseSubset:
|
||||
self.shuffle_caption = shuffle_caption
|
||||
self.caption_separator = caption_separator
|
||||
self.keep_tokens = keep_tokens
|
||||
self.keep_tokens_separator = keep_tokens_separator
|
||||
self.color_aug = color_aug
|
||||
self.flip_aug = flip_aug
|
||||
self.face_crop_aug_range = face_crop_aug_range
|
||||
@@ -395,6 +397,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
shuffle_caption,
|
||||
caption_separator: str,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -415,6 +418,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -449,6 +453,7 @@ class FineTuningSubset(BaseSubset):
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -469,6 +474,7 @@ class FineTuningSubset(BaseSubset):
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -500,6 +506,7 @@ class ControlNetSubset(BaseSubset):
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -520,6 +527,7 @@ class ControlNetSubset(BaseSubset):
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -654,15 +662,33 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
caption = ""
|
||||
else:
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
|
||||
fixed_tokens = []
|
||||
flex_tokens = []
|
||||
if (
|
||||
hasattr(subset, "keep_tokens_separator")
|
||||
and subset.keep_tokens_separator
|
||||
and subset.keep_tokens_separator in caption
|
||||
):
|
||||
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
|
||||
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
||||
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
||||
else:
|
||||
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
|
||||
flex_tokens = tokens[:]
|
||||
if subset.keep_tokens > 0:
|
||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||
flex_tokens = tokens[subset.keep_tokens :]
|
||||
|
||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = (
|
||||
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
|
||||
math.floor(
|
||||
(self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))
|
||||
)
|
||||
+ subset.token_warmup_min
|
||||
)
|
||||
tokens = tokens[:tokens_len]
|
||||
flex_tokens = flex_tokens[:tokens_len]
|
||||
|
||||
def dropout_tags(tokens):
|
||||
if subset.caption_tag_dropout_rate <= 0:
|
||||
@@ -673,12 +699,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
l.append(token)
|
||||
return l
|
||||
|
||||
fixed_tokens = []
|
||||
flex_tokens = tokens[:]
|
||||
if subset.keep_tokens > 0:
|
||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||
flex_tokens = tokens[subset.keep_tokens :]
|
||||
|
||||
if subset.shuffle_caption:
|
||||
random.shuffle(flex_tokens)
|
||||
|
||||
@@ -1724,6 +1744,7 @@ class ControlNetDataset(BaseDataset):
|
||||
subset.shuffle_caption,
|
||||
subset.caption_separator,
|
||||
subset.keep_tokens,
|
||||
subset.keep_tokens_separator,
|
||||
subset.color_aug,
|
||||
subset.flip_aug,
|
||||
subset.face_crop_aug_range,
|
||||
@@ -2827,6 +2848,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
action="store_true",
|
||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
|
||||
)
|
||||
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
|
||||
parser.add_argument(
|
||||
"--dynamo_backend",
|
||||
type=str,
|
||||
default="inductor",
|
||||
# available backends:
|
||||
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
|
||||
# https://pytorch.org/docs/stable/torch.compiler.html
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)"
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument(
|
||||
"--sdpa",
|
||||
@@ -2878,6 +2910,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddp_gradient_as_bucket_view",
|
||||
action="store_true",
|
||||
help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddp_static_graph",
|
||||
action="store_true",
|
||||
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_skip",
|
||||
type=int,
|
||||
@@ -3131,6 +3173,13 @@ def add_dataset_arguments(
|
||||
default=0,
|
||||
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_tokens_separator",
|
||||
type=str,
|
||||
default="",
|
||||
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
|
||||
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_prefix",
|
||||
type=str,
|
||||
@@ -3831,16 +3880,26 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
os.environ["WANDB_DIR"] = logging_dir
|
||||
if args.wandb_api_key is not None:
|
||||
wandb.login(key=args.wandb_api_key)
|
||||
|
||||
# torch.compile のオプション。 NO の場合は torch.compile は使わない
|
||||
dynamo_backend = "NO"
|
||||
if args.torch_compile:
|
||||
dynamo_backend = args.dynamo_backend
|
||||
|
||||
kwargs_handlers = (
|
||||
None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
|
||||
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
|
||||
DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph)
|
||||
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
|
||||
else None,
|
||||
)
|
||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=log_with,
|
||||
project_dir=logging_dir,
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
dynamo_backend=dynamo_backend,
|
||||
)
|
||||
return accelerator
|
||||
|
||||
@@ -4561,7 +4620,7 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
|
||||
def sample_images_common(
|
||||
pipe_class,
|
||||
accelerator,
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
@@ -4598,6 +4657,13 @@ def sample_images_common(
|
||||
org_vae_device = vae.device # CPUにいるはず
|
||||
vae.to(device)
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
if isinstance(text_encoder, (list, tuple)):
|
||||
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
|
||||
else:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
# read prompts
|
||||
|
||||
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||
|
||||
Reference in New Issue
Block a user