mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into gradual_latent_hires_fix
2
.github/workflows/typos.yml
vendored
@@ -18,4 +18,4 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: typos-action
|
- name: typos-action
|
||||||
uses: crate-ci/typos@v1.16.15
|
uses: crate-ci/typos@v1.16.26
|
||||||
|
|||||||
35
README.md
@@ -281,6 +281,41 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
|||||||
|
|
||||||
## Change History
|
## Change History
|
||||||
|
|
||||||
|
### Dec 24, 2023 / 2023/12/24
|
||||||
|
|
||||||
|
- Fixed `tools/convert_diffusers20_original_sd.py`, which had stopped working. Thanks to Disty0! PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016)
|
||||||
|
|
||||||
|
- `tools/convert_diffusers20_original_sd.py`, which was no longer running, has been repaired. Thanks to Disty0. PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016)
|
||||||
|
|
||||||
|
|
||||||
|
### Dec 21, 2023 / 2023/12/21
|
||||||
|
|
||||||
|
- The issues in multi-GPU training are fixed. Thanks to Isotr0py! PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) and [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
|
||||||
|
- `--ddp_gradient_as_bucket_view` and `--ddp_bucket_view` options are added to `sdxl_train.py`. Please specify these options for multi-GPU training.
|
||||||
|
- IPEX support is updated. Thanks to Disty0!
|
||||||
|
- Fixed a bug where the bucket size could become smaller than `min_bucket_reso`. Thanks to Cauldrath! PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
|
||||||
|
- `--sample_at_first` option is added to each training script. This option is useful to generate images at the first step, before training. Thanks to shirayu! PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
|
||||||
|
- `--ss` option is added to the sampling prompt in training. You can specify the scheduler for the sampling like `--ss euler_a`. Thanks to shirayu! PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
|
||||||
|
- `keep_tokens_separator` is added to the dataset config. This option is useful to keep (prevent from shuffling) the tokens in the captions. See [#975](https://github.com/kohya-ss/sd-scripts/pull/975) for details. Thanks to Linaqruf!
|
||||||
|
- You can specify the separator with an option like `--keep_tokens_separator "|||"` or with `keep_tokens_separator = "|||"` in `.toml`. The tokens before `|||` are not shuffled.
|
||||||
|
- Attention processor hook is added. See [#961](https://github.com/kohya-ss/sd-scripts/pull/961) for details. Thanks to rockerBOO!
|
||||||
|
- The optimizer `PagedAdamW` is added. Thanks to xzuyn! PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
|
||||||
|
- NaN replacement in SDXL VAE is sped up. Thanks to liubo0902! PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
|
||||||
|
- Fixed the path error in `finetune/make_captions.py`. Thanks to CjangCjengh! PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
|
||||||
|
|
||||||
|
- Fixed problems in multi-GPU training. Thanks to Isotr0py. PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) and [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
|
||||||
|
- The `--ddp_gradient_as_bucket_view` and `--ddp_bucket_view` options have been added to `sdxl_train.py`. Specify these options when training on multiple GPUs.
|
||||||
|
- IPEX support has been updated. Thanks to Disty0.
|
||||||
|
- Fixed a problem in Aspect Ratio Bucketing where the bucket size fell below `min_bucket_reso`. Thanks to Cauldrath. PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
|
||||||
|
- The `--sample_at_first` option has been added to each training script. Generating images before training makes the training results easier to compare. Thanks to shirayu. PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
|
||||||
|
- The `--ss` option has been added to the prompts used during training. You can specify the scheduler, as in `--ss euler_a`. Thanks to shirayu. PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
|
||||||
|
- `keep_tokens_separator` has been added to the dataset config. It specifies up to which position the tokens in a caption are left unshuffled. See [#975](https://github.com/kohya-ss/sd-scripts/pull/975) for details. Thanks to Linaqruf.
|
||||||
|
- Specify it with an option such as `--keep_tokens_separator "|||"`, or in `.toml` as `keep_tokens_separator = "|||"`. Tokens before `|||` are not shuffled.
|
||||||
|
- An attention processor hook has been added. See [#961](https://github.com/kohya-ss/sd-scripts/pull/961) for details. Thanks to rockerBOO.
|
||||||
|
- The optimizer `PagedAdamW` has been added. Thanks to xzuyn. PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
|
||||||
|
- During training, the replacement of NaN values that occur in the SDXL VAE has been sped up. Thanks to liubo0902. PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
|
||||||
|
- Fixed an error in `finetune/make_captions.py` that occurred when a relative path was specified. Thanks to CjangCjengh. PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
|
||||||
|
|
||||||
### Dec 3, 2023 / 2023/12/3
|
### Dec 3, 2023 / 2023/12/3
|
||||||
|
|
||||||
- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
|
- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
|
||||||
|
|||||||
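The `keep_tokens_separator` entry in the change history describes pinning the tokens that come before the separator while the remaining tokens are shuffled. A minimal sketch of that idea follows, assuming comma-separated captions. The function name `shuffle_caption_with_separator` and its parameters are hypothetical, and details such as whitespace handling or whether the separator is dropped from the output are assumptions, not confirmed sd-scripts behavior.

```python
import random


def shuffle_caption_with_separator(caption, keep_tokens_separator="|||",
                                   caption_separator=",", rng=None):
    """Shuffle only the tokens that come after keep_tokens_separator.

    Hypothetical helper for illustration; not the sd-scripts implementation.
    """
    rng = rng or random.Random()
    if keep_tokens_separator in caption:
        fixed_part, flex_part = caption.split(keep_tokens_separator, 1)
    else:
        # No separator present: nothing is pinned, everything may be shuffled.
        fixed_part, flex_part = "", caption

    def split_tokens(text):
        return [t.strip() for t in text.split(caption_separator) if t.strip()]

    fixed_tokens = split_tokens(fixed_part)
    flex_tokens = split_tokens(flex_part)
    rng.shuffle(flex_tokens)  # only the part after the separator moves
    return (caption_separator + " ").join(fixed_tokens + flex_tokens)


result = shuffle_caption_with_separator(
    "1girl, solo ||| smile, outdoors, blue sky", rng=random.Random(0)
)
print(result)  # the first two tokens are always "1girl, solo"
```

Passing a seeded `random.Random` keeps the example reproducible; in training one would presumably rely on the global random state instead.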
@@ -13,7 +13,7 @@ import torch
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
sys.path.append(os.path.dirname(__file__))
|
sys.path.append(os.path.dirname(__file__))
|
||||||
from blip.blip import blip_decoder
|
from blip.blip import blip_decoder, is_url
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@@ -76,6 +76,8 @@ def main(args):
|
|||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
print("Current Working Directory is: ", cwd)
|
print("Current Working Directory is: ", cwd)
|
||||||
os.chdir("finetune")
|
os.chdir("finetune")
|
||||||
|
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
|
||||||
|
args.caption_weights = os.path.join("..", args.caption_weights)
|
||||||
|
|
||||||
print(f"load images from {args.train_data_dir}")
|
print(f"load images from {args.train_data_dir}")
|
||||||
train_data_dir_path = Path(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
|
|||||||
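The hunk above adjusts `args.caption_weights` after `os.chdir("finetune")`: a value that is neither a URL nor an existing file in the new working directory is assumed to be relative to the parent directory. The sketch below restates that logic in isolation. `is_url` here is a rough stand-in written for this example (the real one is imported from `blip.blip`), and `resolve_after_chdir` with its injectable `is_file` argument is a hypothetical name used only to make the behavior testable.

```python
import os
from urllib.parse import urlparse


def is_url(path_or_url):
    """Rough stand-in for blip.blip.is_url: true for http(s) URLs (assumption)."""
    return urlparse(path_or_url).scheme in ("http", "https")


def resolve_after_chdir(weights, is_file=os.path.isfile):
    """Return a weights location that is still valid after entering a subdirectory."""
    if not is_url(weights) and not is_file(weights):
        # The path was given relative to the previous working directory,
        # which is now one level up.
        return os.path.join("..", weights)
    return weights


print(resolve_after_chdir("https://example.com/model.pth"))
print(resolve_after_chdir("models/missing.pth", is_file=lambda p: False))
```

URLs and paths that already resolve are returned unchanged, so absolute paths keep working as before.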
@@ -53,6 +53,7 @@ class BaseSubsetParams:
|
|||||||
shuffle_caption: bool = False
|
shuffle_caption: bool = False
|
||||||
caption_separator: str = ','
|
caption_separator: str = ','
|
||||||
keep_tokens: int = 0
|
keep_tokens: int = 0
|
||||||
|
keep_tokens_separator: Optional[str] = None
|
||||||
color_aug: bool = False
|
color_aug: bool = False
|
||||||
flip_aug: bool = False
|
flip_aug: bool = False
|
||||||
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
||||||
@@ -160,6 +161,7 @@ class ConfigSanitizer:
|
|||||||
"random_crop": bool,
|
"random_crop": bool,
|
||||||
"shuffle_caption": bool,
|
"shuffle_caption": bool,
|
||||||
"keep_tokens": int,
|
"keep_tokens": int,
|
||||||
|
"keep_tokens_separator": str,
|
||||||
"token_warmup_min": int,
|
"token_warmup_min": int,
|
||||||
"token_warmup_step": Any(float,int),
|
"token_warmup_step": Any(float,int),
|
||||||
"caption_prefix": str,
|
"caption_prefix": str,
|
||||||
@@ -461,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
|||||||
num_repeats: {subset.num_repeats}
|
num_repeats: {subset.num_repeats}
|
||||||
shuffle_caption: {subset.shuffle_caption}
|
shuffle_caption: {subset.shuffle_caption}
|
||||||
keep_tokens: {subset.keep_tokens}
|
keep_tokens: {subset.keep_tokens}
|
||||||
|
keep_tokens_separator: {subset.keep_tokens_separator}
|
||||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||||
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
||||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||||
|
|||||||
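The hunks above thread `keep_tokens_separator` through the subset parameters, the sanitizer schema, and the debug printout. When declaring an optional string field of this kind in a dataclass, one detail is easy to miss: a trailing comma after the default turns the default into a one-element tuple. The sketch below shows both forms; the class names are hypothetical, reduced stand-ins, not the real `BaseSubsetParams`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SubsetParamsSketch:  # hypothetical, reduced stand-in for illustration
    shuffle_caption: bool = False
    keep_tokens: int = 0
    keep_tokens_separator: Optional[str] = None  # no trailing comma


@dataclass
class TrailingCommaPitfall:  # hypothetical, shows the pitfall only
    caption_separator: str = ",",  # trailing comma: the default is the tuple (",",)


print(SubsetParamsSketch().keep_tokens_separator)  # None
print(TrailingCommaPitfall().caption_separator)    # (',',)
```

Because type annotations are not enforced at runtime, the tuple default goes unnoticed until the value is used as a string.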
@@ -140,6 +140,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
|
|
||||||
# C
|
# C
|
||||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||||
|
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
|
||||||
ipex._C._DeviceProperties.major = 2023
|
ipex._C._DeviceProperties.major = 2023
|
||||||
ipex._C._DeviceProperties.minor = 2
|
ipex._C._DeviceProperties.minor = 2
|
||||||
|
|
||||||
@@ -156,20 +157,9 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.get_device_properties.minor = 7
|
torch.cuda.get_device_properties.minor = 7
|
||||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||||
if hasattr(torch.xpu, 'getDeviceIdListForCard'):
|
|
||||||
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
|
|
||||||
torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
|
|
||||||
else:
|
|
||||||
torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
|
|
||||||
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
|
|
||||||
|
|
||||||
ipex_hijacks()
|
ipex_hijacks()
|
||||||
if not torch.xpu.has_fp64_dtype():
|
if not torch.xpu.has_fp64_dtype():
|
||||||
try:
|
|
||||||
from .attention import attention_init
|
|
||||||
attention_init()
|
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
|
||||||
pass
|
|
||||||
try:
|
try:
|
||||||
from .diffusers import ipex_diffusers
|
from .diffusers import ipex_diffusers
|
||||||
ipex_diffusers()
|
ipex_diffusers()
|
||||||
|
|||||||
@@ -1,45 +1,98 @@
|
|||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
original_torch_bmm = torch.bmm
|
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
|
||||||
def torch_bmm(input, mat2, *, out=None):
|
|
||||||
if input.dtype != mat2.dtype:
|
|
||||||
mat2 = mat2.to(input.dtype)
|
|
||||||
|
|
||||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
||||||
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||||
block_multiply = input.element_size()
|
|
||||||
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
|
# Halve the slice size until the resulting block fits under the slice rate
|
||||||
|
@cache
|
||||||
|
def find_slice_size(slice_size, slice_block_size):
|
||||||
|
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||||
|
slice_size = slice_size // 2
|
||||||
|
if slice_size <= 1:
|
||||||
|
slice_size = 1
|
||||||
|
break
|
||||||
|
return slice_size
|
||||||
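The new `find_slice_size` helper halves a candidate slice size until the product of slice size and per-slice block size drops to the threshold or below, with a floor of 1. A minimal standalone sketch follows. The `slice_rate` parameter is an assumption added here so the example does not depend on the environment; the source reads the threshold into the module-level `attention_slice_rate` instead.

```python
from functools import cache


@cache
def find_slice_size(slice_size, slice_block_size, slice_rate=4.0):
    """Halve slice_size until slice_size * slice_block_size fits under slice_rate."""
    while slice_size * slice_block_size > slice_rate:
        slice_size = slice_size // 2
        if slice_size <= 1:
            return 1  # cannot go below one element per slice
    return slice_size


print(find_slice_size(16, 1.0))   # 4  (16 -> 8 -> 4)
print(find_slice_size(8, 0.25))   # 8  (already fits, unchanged)
print(find_slice_size(3, 10.0))   # 1  (floors at 1)
```

The comparison is strict, so a slice whose block size lands exactly on the threshold is accepted without further halving.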
|
|
||||||
|
# Find slice sizes for SDPA
|
||||||
|
@cache
|
||||||
|
def find_sdpa_slice_sizes(query_shape, query_element_size):
|
||||||
|
if len(query_shape) == 3:
|
||||||
|
batch_size_attention, query_tokens, shape_three = query_shape
|
||||||
|
shape_four = 1
|
||||||
|
else:
|
||||||
|
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||||
|
|
||||||
|
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||||
block_size = batch_size_attention * slice_block_size
|
block_size = batch_size_attention * slice_block_size
|
||||||
|
|
||||||
split_slice_size = batch_size_attention
|
split_slice_size = batch_size_attention
|
||||||
if block_size > 4:
|
split_2_slice_size = query_tokens
|
||||||
|
split_3_slice_size = shape_three
|
||||||
|
|
||||||
|
do_split = False
|
||||||
|
do_split_2 = False
|
||||||
|
do_split_3 = False
|
||||||
|
|
||||||
|
if block_size > sdpa_slice_trigger_rate:
|
||||||
do_split = True
|
do_split = True
|
||||||
#Find something divisible with the input_tokens
|
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||||
while (split_slice_size * slice_block_size) > 4:
|
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||||
split_slice_size = split_slice_size // 2
|
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||||
if split_slice_size <= 1:
|
do_split_2 = True
|
||||||
split_slice_size = 1
|
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||||
break
|
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||||
else:
|
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||||
do_split = False
|
do_split_3 = True
|
||||||
|
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||||
|
|
||||||
|
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||||
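`find_sdpa_slice_sizes` applies the halving helper as a cascade: it first slices the batch dimension, and only if a single batch slice is still too large does it slice the token dimension, then the third dimension. The sketch below reproduces that cascade under a few assumptions: the function name `plan_slices`, the dictionary return value, and the explicit `trigger_rate` and `slice_rate` parameters are choices made for this example; the source returns a six-element tuple and reads both thresholds from environment variables.

```python
def find_slice_size(slice_size, slice_block_size, slice_rate):
    """Halve slice_size until the block fits under slice_rate (floor of 1)."""
    while slice_size * slice_block_size > slice_rate:
        slice_size //= 2
        if slice_size <= 1:
            return 1
    return slice_size


def plan_slices(shape, element_size, trigger_rate=4.0, slice_rate=4.0):
    """Decide which dimensions to slice for a 3-D or 4-D query shape."""
    if len(shape) == 3:
        batch, tokens, dim3 = shape
        dim4 = 1
    else:
        batch, tokens, dim3, dim4 = shape
    mib = 1024 * 1024
    slice_block = tokens * dim3 * dim4 / mib * element_size
    plan = {"split": False, "split_2": False, "split_3": False,
            "size": batch, "size_2": tokens, "size_3": dim3}
    if batch * slice_block > trigger_rate:
        plan["split"] = True
        plan["size"] = find_slice_size(batch, slice_block, slice_rate)
        if plan["size"] * slice_block > slice_rate:
            # One batch slice is still too large: slice the token dimension too.
            block_2 = plan["size"] * dim3 * dim4 / mib * element_size
            plan["split_2"] = True
            plan["size_2"] = find_slice_size(tokens, block_2, slice_rate)
            if plan["size_2"] * block_2 > slice_rate:
                block_3 = plan["size"] * plan["size_2"] * dim4 / mib * element_size
                plan["split_3"] = True
                plan["size_3"] = find_slice_size(dim3, block_3, slice_rate)
    return plan


# 16 x 4096 x 64 with 2-byte elements: only the batch dimension is sliced.
print(plan_slices((16, 4096, 64), 2)["size"])        # 8
# 2 x 16384 x 1024 with 2-byte elements: batch and token dimensions are sliced.
print(plan_slices((2, 16384, 1024), 2)["size_2"])    # 2048
```

For the first shape the per-batch block is 0.5, so the full block of 8 exceeds the threshold of 4 and the batch is halved once to 8 entries per slice.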
|
|
||||||
|
# Find slice sizes for BMM
|
||||||
|
@cache
|
||||||
|
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
|
||||||
|
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
|
||||||
|
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||||
|
block_size = batch_size_attention * slice_block_size
|
||||||
|
|
||||||
|
split_slice_size = batch_size_attention
|
||||||
split_2_slice_size = input_tokens
|
split_2_slice_size = input_tokens
|
||||||
if split_slice_size * slice_block_size > 4:
|
split_3_slice_size = mat2_atten_shape
|
||||||
slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
|
|
||||||
do_split_2 = True
|
|
||||||
#Find something divisible with the input_tokens
|
|
||||||
while (split_2_slice_size * slice_block_size2) > 4:
|
|
||||||
split_2_slice_size = split_2_slice_size // 2
|
|
||||||
if split_2_slice_size <= 1:
|
|
||||||
split_2_slice_size = 1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
do_split_2 = False
|
|
||||||
|
|
||||||
|
do_split = False
|
||||||
|
do_split_2 = False
|
||||||
|
do_split_3 = False
|
||||||
|
|
||||||
|
if block_size > attention_slice_rate:
|
||||||
|
do_split = True
|
||||||
|
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||||
|
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||||
|
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||||
|
do_split_2 = True
|
||||||
|
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||||
|
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||||
|
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
|
||||||
|
do_split_3 = True
|
||||||
|
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||||
|
|
||||||
|
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||||
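The slice-size functions are decorated with `functools.cache` (Python 3.9+), so the planning work runs once per distinct combination of arguments. This works because the arguments are hashable: `torch.Size` is a tuple subclass, and element sizes are integers. The sketch below illustrates the effect with plain tuples; `plan_for` is a hypothetical stand-in for an expensive planning function.

```python
from functools import cache

calls = []  # records every real (non-cached) invocation


@cache
def plan_for(shape, element_size):
    """Stand-in for an expensive planning function keyed by tensor shape."""
    calls.append(shape)
    batch, tokens, dim = shape
    return batch * tokens * dim * element_size


plan_for((16, 4096, 64), 2)
plan_for((16, 4096, 64), 2)   # same arguments: served from the cache
plan_for((8, 4096, 64), 2)    # new shape: computed again

print(len(calls))                   # 2
print(plan_for.cache_info().hits)   # 1
```

One consequence worth keeping in mind: results are keyed only by the arguments, so a threshold read once at import time and changed later would not invalidate entries already cached.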
|
|
||||||
|
|
||||||
|
original_torch_bmm = torch.bmm
|
||||||
|
def torch_bmm_32_bit(input, mat2, *, out=None):
|
||||||
|
if input.device.type != "xpu":
|
||||||
|
return original_torch_bmm(input, mat2, out=out)
|
||||||
|
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
|
||||||
|
|
||||||
|
# Slice BMM
|
||||||
if do_split:
|
if do_split:
|
||||||
|
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
||||||
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
start_idx = i * split_slice_size
|
start_idx = i * split_slice_size
|
||||||
@@ -48,11 +101,21 @@ def torch_bmm(input, mat2, *, out=None):
|
|||||||
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||||
start_idx_2 = i2 * split_2_slice_size
|
start_idx_2 = i2 * split_2_slice_size
|
||||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
if do_split_3:
|
||||||
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
|
||||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
start_idx_3 = i3 * split_3_slice_size
|
||||||
out=out
|
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||||
)
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
|
||||||
|
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||||
|
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||||
|
out=out
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
||||||
|
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||||
|
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||||
|
out=out
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
||||||
input[start_idx:end_idx],
|
input[start_idx:end_idx],
|
||||||
@@ -64,51 +127,14 @@ def torch_bmm(input, mat2, *, out=None):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
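Slicing a batched matrix multiply along the batch dimension gives the same result as the unsliced call, because each batch entry is multiplied independently. The sketch below checks this with nested Python lists instead of `torch.bmm`, purely to stay dependency-free; the function names are hypothetical and only the batch-dimension slicing is modeled.

```python
def matmul(a, b):
    """Plain 2-D matrix product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def bmm(batch_a, batch_b):
    """Batched matrix product: one independent matmul per batch entry."""
    return [matmul(a, b) for a, b in zip(batch_a, batch_b)]


def sliced_bmm(batch_a, batch_b, slice_size):
    """Compute bmm in chunks of slice_size batch entries."""
    out = [None] * len(batch_a)
    for start in range(0, len(batch_a), slice_size):
        end = start + slice_size
        out[start:end] = bmm(batch_a[start:end], batch_b[start:end])
    return out


# Five batch entries of shape 2x3 and 3x2, filled with small integers.
a = [[[i + j + k for k in range(3)] for j in range(2)] for i in range(5)]
b = [[[i * j + k for k in range(2)] for j in range(3)] for i in range(5)]

print(sliced_bmm(a, b, 2) == bmm(a, b))  # True
```

Only one chunk of inputs and outputs is processed at a time, which is the property the sliced implementation relies on to keep individual allocations small.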
|
|
||||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
if query.device.type != "xpu":
|
||||||
if len(query.shape) == 3:
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||||
batch_size_attention, query_tokens, shape_four = query.shape
|
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
||||||
shape_one = 1
|
|
||||||
no_shape_one = True
|
|
||||||
else:
|
|
||||||
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
|
|
||||||
no_shape_one = False
|
|
||||||
|
|
||||||
if query.dtype != key.dtype:
|
|
||||||
key = key.to(dtype=query.dtype)
|
|
||||||
if query.dtype != value.dtype:
|
|
||||||
value = value.to(dtype=query.dtype)
|
|
||||||
|
|
||||||
block_multiply = query.element_size()
|
|
||||||
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
|
|
||||||
block_size = batch_size_attention * slice_block_size
|
|
||||||
|
|
||||||
split_slice_size = batch_size_attention
|
|
||||||
if block_size > 4:
|
|
||||||
do_split = True
|
|
||||||
#Find something divisible with the shape_one
|
|
||||||
while (split_slice_size * slice_block_size) > 4:
|
|
||||||
split_slice_size = split_slice_size // 2
|
|
||||||
if split_slice_size <= 1:
|
|
||||||
split_slice_size = 1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
do_split = False
|
|
||||||
|
|
||||||
split_2_slice_size = query_tokens
|
|
||||||
if split_slice_size * slice_block_size > 4:
|
|
||||||
slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
|
|
||||||
do_split_2 = True
|
|
||||||
#Find something divisible with the batch_size_attention
|
|
||||||
while (split_2_slice_size * slice_block_size2) > 4:
|
|
||||||
split_2_slice_size = split_2_slice_size // 2
|
|
||||||
if split_2_slice_size <= 1:
|
|
||||||
split_2_slice_size = 1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
do_split_2 = False
|
|
||||||
|
|
||||||
|
# Slice SDPA
|
||||||
if do_split:
|
if do_split:
|
||||||
|
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
start_idx = i * split_slice_size
|
start_idx = i * split_slice_size
|
||||||
@@ -117,7 +143,18 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
|||||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||||
start_idx_2 = i2 * split_2_slice_size
|
start_idx_2 = i2 * split_2_slice_size
|
||||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||||
if no_shape_one:
|
if do_split_3:
|
||||||
|
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||||
|
start_idx_3 = i3 * split_3_slice_size
|
||||||
|
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||||
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
|
||||||
|
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||||
|
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||||
|
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||||
|
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
||||||
|
dropout_p=dropout_p, is_causal=is_causal
|
||||||
|
)
|
||||||
|
else:
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||||
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||||
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||||
@@ -125,38 +162,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
|||||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||||
dropout_p=dropout_p, is_causal=is_causal
|
dropout_p=dropout_p, is_causal=is_causal
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
|
||||||
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
|
||||||
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
|
||||||
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
|
||||||
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
|
||||||
dropout_p=dropout_p, is_causal=is_causal
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if no_shape_one:
|
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
query[start_idx:end_idx],
|
||||||
query[start_idx:end_idx],
|
key[start_idx:end_idx],
|
||||||
key[start_idx:end_idx],
|
value[start_idx:end_idx],
|
||||||
value[start_idx:end_idx],
|
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
dropout_p=dropout_p, is_causal=is_causal
|
||||||
dropout_p=dropout_p, is_causal=is_causal
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
|
|
||||||
query[:, start_idx:end_idx],
|
|
||||||
key[:, start_idx:end_idx],
|
|
||||||
value[:, start_idx:end_idx],
|
|
||||||
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
|
|
||||||
dropout_p=dropout_p, is_causal=is_causal
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return original_scaled_dot_product_attention(
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||||
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
|
|
||||||
)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def attention_init():
|
|
||||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
|
||||||
torch.bmm = torch_bmm
|
|
||||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
|
||||||
|
|||||||
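The slicing loops in this file iterate over `range(n // slice_size)`, which visits every index only when the slice size divides the dimension evenly. Whether non-divisible sizes can occur in practice is not stated in the diff. As a hedged illustration, the hypothetical helper below generates index ranges that also cover a trailing remainder.

```python
def slice_ranges(total, slice_size):
    """Yield (start, end) pairs that cover range(total) completely."""
    for start in range(0, total, slice_size):
        yield start, min(start + slice_size, total)


print(list(slice_ranges(8, 4)))   # [(0, 4), (4, 8)]
print(list(slice_ranges(7, 3)))   # [(0, 3), (3, 6), (6, 7)]
print(7 // 3)                     # 2, so a floor-division loop would stop at index 6
```

With power-of-two dimensions, repeated halving always produces a divisor, so both approaches visit the same ranges in that case.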
@@ -1,10 +1,62 @@
|
|||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
import diffusers #0.24.0 # pylint: disable=import-error
|
import diffusers #0.24.0 # pylint: disable=import-error
|
||||||
from diffusers.models.attention_processor import Attention
|
from diffusers.models.attention_processor import Attention
|
||||||
|
from diffusers.utils import USE_PEFT_BACKEND
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
|
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def find_slice_size(slice_size, slice_block_size):
|
||||||
|
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||||
|
slice_size = slice_size // 2
|
||||||
|
if slice_size <= 1:
|
||||||
|
slice_size = 1
|
||||||
|
break
|
||||||
|
return slice_size
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
||||||
|
if len(query_shape) == 3:
|
||||||
|
batch_size_attention, query_tokens, shape_three = query_shape
|
||||||
|
shape_four = 1
|
||||||
|
else:
|
||||||
|
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||||
|
if slice_size is not None:
|
||||||
|
batch_size_attention = slice_size
|
||||||
|
|
||||||
|
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||||
|
block_size = batch_size_attention * slice_block_size
|
||||||
|
|
||||||
|
split_slice_size = batch_size_attention
|
||||||
|
split_2_slice_size = query_tokens
|
||||||
|
split_3_slice_size = shape_three
|
||||||
|
|
||||||
|
do_split = False
|
||||||
|
do_split_2 = False
|
||||||
|
do_split_3 = False
|
||||||
|
|
||||||
|
if query_device_type != "xpu":
|
||||||
|
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||||
|
|
||||||
|
if block_size > attention_slice_rate:
|
||||||
|
do_split = True
|
||||||
|
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||||
|
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||||
|
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||||
|
do_split_2 = True
|
||||||
|
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||||
|
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||||
|
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||||
|
do_split_3 = True
|
||||||
|
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||||
|
|
||||||
|
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||||
|
|
||||||
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||||
r"""
|
r"""
|
||||||
Processor for implementing sliced attention.
|
Processor for implementing sliced attention.
|
||||||
@@ -18,7 +70,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
|||||||
def __init__(self, slice_size):
|
def __init__(self, slice_size):
|
||||||
self.slice_size = slice_size
|
self.slice_size = slice_size
|
||||||
|
|
||||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
input_ndim = hidden_states.ndim
|
input_ndim = hidden_states.ndim
|
||||||
@@ -54,49 +108,61 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
|||||||
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
####################################################################
|
||||||
block_multiply = query.element_size()
|
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||||
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
|
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
||||||
block_size = query_tokens * slice_block_size
|
|
||||||
split_2_slice_size = query_tokens
|
|
||||||
if block_size > 4:
|
|
||||||
do_split_2 = True
|
|
||||||
#Find something divisible with the query_tokens
|
|
||||||
while (split_2_slice_size * slice_block_size) > 4:
|
|
||||||
split_2_slice_size = split_2_slice_size // 2
|
|
||||||
if split_2_slice_size <= 1:
|
|
||||||
split_2_slice_size = 1
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
do_split_2 = False
|
|
||||||
|
|
||||||
for i in range(batch_size_attention // self.slice_size):
|
|
||||||
start_idx = i * self.slice_size
|
|
||||||
end_idx = (i + 1) * self.slice_size
|
|
||||||
|
|
||||||
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
|
start_idx = i * split_slice_size
|
||||||
|
end_idx = (i + 1) * split_slice_size
|
||||||
if do_split_2:
|
if do_split_2:
|
||||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||||
start_idx_2 = i2 * split_2_slice_size
|
start_idx_2 = i2 * split_2_slice_size
|
||||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||||
|
if do_split_3:
|
||||||
|
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||||
|
start_idx_3 = i3 * split_3_slice_size
|
||||||
|
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||||
|
|
||||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
else:
|
||||||
|
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||||
|
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||||
|
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||||
|
|
||||||
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||||
|
|
||||||
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||||
|
del attn_slice
|
||||||
else:
|
else:
|
||||||
query_slice = query[start_idx:end_idx]
|
query_slice = query[start_idx:end_idx]
|
||||||
key_slice = key[start_idx:end_idx]
|
key_slice = key[start_idx:end_idx]
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx] = attn_slice
|
hidden_states[start_idx:end_idx] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
####################################################################
|
||||||
|
|
||||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||||
|
|
||||||
@@ -115,6 +181,130 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AttnProcessor:
|
||||||
|
r"""
|
||||||
|
Default processor for performing attention-related computations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states=None, attention_mask=None,
|
||||||
|
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
args = () if USE_PEFT_BACKEND else (scale,)
|
||||||
|
|
||||||
|
if attn.spatial_norm is not None:
|
||||||
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states, *args)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
elif attn.norm_cross:
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states, *args)
|
||||||
|
value = attn.to_v(encoder_hidden_states, *args)
|
||||||
|
|
||||||
|
query = attn.head_to_batch_dim(query)
|
||||||
|
key = attn.head_to_batch_dim(key)
|
||||||
|
value = attn.head_to_batch_dim(value)
|
||||||
|
|
||||||
|
####################################################################
|
||||||
|
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||||
|
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||||
|
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||||
|
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
||||||
|
|
||||||
|
if do_split:
|
||||||
|
for i in range(batch_size_attention // split_slice_size):
|
||||||
|
start_idx = i * split_slice_size
|
||||||
|
end_idx = (i + 1) * split_slice_size
|
||||||
|
if do_split_2:
|
||||||
|
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||||
|
start_idx_2 = i2 * split_2_slice_size
|
||||||
|
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||||
|
if do_split_3:
|
||||||
|
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||||
|
start_idx_3 = i3 * split_3_slice_size
|
||||||
|
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||||
|
|
||||||
|
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||||
|
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||||
|
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||||
|
|
||||||
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||||
|
|
||||||
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
else:
|
||||||
|
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||||
|
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||||
|
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||||
|
|
||||||
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||||
|
|
||||||
|
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
else:
|
||||||
|
query_slice = query[start_idx:end_idx]
|
||||||
|
key_slice = key[start_idx:end_idx]
|
||||||
|
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||||
|
|
||||||
|
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||||
|
del query_slice
|
||||||
|
del key_slice
|
||||||
|
del attn_mask_slice
|
||||||
|
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||||
|
|
||||||
|
hidden_states[start_idx:end_idx] = attn_slice
|
||||||
|
del attn_slice
|
||||||
|
else:
|
||||||
|
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||||
|
hidden_states = torch.bmm(attention_probs, value)
|
||||||
|
####################################################################
|
||||||
|
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def ipex_diffusers():
|
def ipex_diffusers():
|
||||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
#ARC GPUs can't allocate more than 4GB to a single block:
|
||||||
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
||||||
|
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
|
||||||
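The inline logic removed from `SlicedAttnProcessor` in this file shrinks the second-dimension slice by repeated halving until one block fits under a threshold of 4. A minimal pure-Python sketch of that halving step follows. The function names `slice_block_size` and `halve_until_fits` are hypothetical, and this is not the body of `find_attention_slice_sizes`, which this diff does not show.

```python
def slice_block_size(slice_size: int, shape_three: int, element_size: int) -> float:
    # Same arithmetic as the removed line in the diff:
    # slice_size * shape_three / 1024 / 1024 * element_size
    return slice_size * shape_three / 1024 / 1024 * element_size


def halve_until_fits(query_tokens: int, per_token_size: float, limit: float = 4.0):
    """Return (do_split, slice_size), mirroring the removed do_split_2 logic."""
    slice_size = query_tokens
    if query_tokens * per_token_size <= limit:
        # The whole block already fits, so no split is needed.
        return False, slice_size
    while slice_size * per_token_size > limit:
        slice_size //= 2
        if slice_size <= 1:
            # Never go below one token per slice.
            slice_size = 1
            break
    return True, slice_size


per_token = slice_block_size(slice_size=1, shape_three=4096, element_size=2)
print(halve_until_fits(4096, per_token))  # (True, 512)
print(halve_until_fits(256, per_token))   # (False, 256)
```

The example values (one slice, 4096 columns, 2-byte elements) are illustrative and are not taken from the commit.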
|
|||||||
@@ -1,67 +1,9 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import importlib
|
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||||
|
|
||||||
class CondFunc: # pylint: disable=missing-class-docstring
|
|
||||||
def __new__(cls, orig_func, sub_func, cond_func):
|
|
||||||
self = super(CondFunc, cls).__new__(cls)
|
|
||||||
if isinstance(orig_func, str):
|
|
||||||
func_path = orig_func.split('.')
|
|
||||||
for i in range(len(func_path)-1, -1, -1):
|
|
||||||
try:
|
|
||||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
|
||||||
break
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
for attr_name in func_path[i:-1]:
|
|
||||||
resolved_obj = getattr(resolved_obj, attr_name)
|
|
||||||
orig_func = getattr(resolved_obj, func_path[-1])
|
|
||||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
|
||||||
self.__init__(orig_func, sub_func, cond_func)
|
|
||||||
return lambda *args, **kwargs: self(*args, **kwargs)
|
|
||||||
def __init__(self, orig_func, sub_func, cond_func):
|
|
||||||
self.__orig_func = orig_func
|
|
||||||
self.__sub_func = sub_func
|
|
||||||
self.__cond_func = cond_func
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
|
|
||||||
return self.__sub_func(self.__orig_func, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
return self.__orig_func(*args, **kwargs)
|
|
||||||
|
|
||||||
_utils = torch.utils.data._utils
|
|
||||||
def _shutdown_workers(self):
|
|
||||||
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
|
|
||||||
return
|
|
||||||
if hasattr(self, "_shutdown") and not self._shutdown:
|
|
||||||
self._shutdown = True
|
|
||||||
try:
|
|
||||||
if hasattr(self, '_pin_memory_thread'):
|
|
||||||
self._pin_memory_thread_done_event.set()
|
|
||||||
self._worker_result_queue.put((None, None))
|
|
||||||
self._pin_memory_thread.join()
|
|
||||||
self._worker_result_queue.cancel_join_thread()
|
|
||||||
self._worker_result_queue.close()
|
|
||||||
self._workers_done_event.set()
|
|
||||||
for worker_id in range(len(self._workers)):
|
|
||||||
if self._persistent_workers or self._workers_status[worker_id]:
|
|
||||||
self._mark_worker_as_unavailable(worker_id, shutdown=True)
|
|
||||||
for w in self._workers: # pylint: disable=invalid-name
|
|
||||||
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
|
|
||||||
for q in self._index_queues: # pylint: disable=invalid-name
|
|
||||||
q.cancel_join_thread()
|
|
||||||
q.close()
|
|
||||||
finally:
|
|
||||||
if self._worker_pids_set:
|
|
||||||
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
|
|
||||||
self._worker_pids_set = False
|
|
||||||
for w in self._workers: # pylint: disable=invalid-name
|
|
||||||
if w.is_alive():
|
|
||||||
w.terminate()
|
|
||||||
|
|
||||||
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
||||||
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
||||||
if isinstance(device_ids, list) and len(device_ids) > 1:
|
if isinstance(device_ids, list) and len(device_ids) > 1:
|
||||||
@@ -71,17 +13,18 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
|
|||||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_cuda(self):
|
||||||
|
return self.device.type == 'xpu' or self.device.type == 'cuda'
|
||||||
|
|
||||||
def check_device(device):
|
def check_device(device):
|
||||||
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
||||||
|
|
||||||
def return_xpu(device):
|
def return_xpu(device):
|
||||||
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
||||||
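The two helpers above decide whether a device argument names CUDA and, if so, rewrite it to the matching XPU device. The sketch below is torch-free and covers only the string and integer cases. The `torch.device` branch is omitted so that it runs without PyTorch. The `_sketch` names and `fake_empty` are hypothetical stand-ins, not part of the commit.

```python
def check_device_sketch(device) -> bool:
    # True for "cuda", "cuda:1", or a bare integer device index.
    return bool((isinstance(device, str) and "cuda" in device) or isinstance(device, int))


def return_xpu_sketch(device):
    # "cuda:1" -> "xpu:1", 0 -> "xpu:0", anything else -> "xpu"
    if isinstance(device, str) and ":" in device:
        return f"xpu:{device.split(':')[-1]}"
    if isinstance(device, int):
        return f"xpu:{device}"
    return "xpu"


def fake_empty(*size, device=None):
    # Hypothetical stand-in for a factory function that takes a device keyword.
    return {"size": size, "device": device}


def hijacked_empty(*args, device=None, **kwargs):
    # Wrapper shape used on the new side of this diff: remap only CUDA-like devices.
    if check_device_sketch(device):
        return fake_empty(*args, device=return_xpu_sketch(device), **kwargs)
    return fake_empty(*args, device=device, **kwargs)


print(hijacked_empty(2, 3, device="cuda:0"))  # {'size': (2, 3), 'device': 'xpu:0'}
print(hijacked_empty(2, 3, device="cpu"))     # {'size': (2, 3), 'device': 'cpu'}
```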
|
|
||||||
def ipex_no_cuda(orig_func, *args, **kwargs):
|
|
||||||
torch.cuda.is_available = lambda: False
|
|
||||||
orig_func(*args, **kwargs)
|
|
||||||
torch.cuda.is_available = torch.xpu.is_available
|
|
||||||
|
|
||||||
|
# Autocast
|
||||||
original_autocast = torch.autocast
|
original_autocast = torch.autocast
|
||||||
def ipex_autocast(*args, **kwargs):
|
def ipex_autocast(*args, **kwargs):
|
||||||
if len(args) > 0 and args[0] == "cuda":
|
if len(args) > 0 and args[0] == "cuda":
|
||||||
@@ -89,15 +32,7 @@ def ipex_autocast(*args, **kwargs):
|
|||||||
else:
|
else:
|
||||||
return original_autocast(*args, **kwargs)
|
return original_autocast(*args, **kwargs)
|
||||||
|
|
||||||
# Embedding BF16
|
# Latent Antialias CPU Offload:
|
||||||
original_torch_cat = torch.cat
|
|
||||||
def torch_cat(tensor, *args, **kwargs):
|
|
||||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
|
||||||
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
|
||||||
else:
|
|
||||||
return original_torch_cat(tensor, *args, **kwargs)
|
|
||||||
|
|
||||||
# Latent antialias:
|
|
||||||
original_interpolate = torch.nn.functional.interpolate
|
original_interpolate = torch.nn.functional.interpolate
|
||||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||||
if antialias or align_corners is not None:
|
if antialias or align_corners is not None:
|
||||||
@@ -109,110 +44,205 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
|
|||||||
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
||||||
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
||||||
|
|
||||||
original_linalg_solve = torch.linalg.solve
|
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
||||||
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
|
original_from_numpy = torch.from_numpy
|
||||||
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
|
def from_numpy(ndarray):
|
||||||
return_device = A.device
|
if ndarray.dtype == float:
|
||||||
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
|
return original_from_numpy(ndarray.astype('float32'))
|
||||||
else:
|
else:
|
||||||
return original_linalg_solve(A, B, *args, **kwargs)
|
return original_from_numpy(ndarray)
|
||||||
|
|
||||||
@property
|
if torch.xpu.has_fp64_dtype():
|
||||||
def is_cuda(self):
|
original_torch_bmm = torch.bmm
|
||||||
return self.device.type == 'xpu'
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
|
else:
|
||||||
|
# 32 bit attention workarounds for Alchemist:
|
||||||
|
try:
|
||||||
|
from .attention import torch_bmm_32_bit as original_torch_bmm
|
||||||
|
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
original_torch_bmm = torch.bmm
|
||||||
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
|
# Data Type Errors:
|
||||||
|
def torch_bmm(input, mat2, *, out=None):
|
||||||
|
if input.dtype != mat2.dtype:
|
||||||
|
mat2 = mat2.to(input.dtype)
|
||||||
|
return original_torch_bmm(input, mat2, out=out)
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||||
|
if query.dtype != key.dtype:
|
||||||
|
key = key.to(dtype=query.dtype)
|
||||||
|
if query.dtype != value.dtype:
|
||||||
|
value = value.to(dtype=query.dtype)
|
||||||
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||||
|
|
||||||
|
# A1111 FP16
|
||||||
|
original_functional_group_norm = torch.nn.functional.group_norm
|
||||||
|
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
||||||
|
if weight is not None and input.dtype != weight.data.dtype:
|
||||||
|
input = input.to(dtype=weight.data.dtype)
|
||||||
|
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
||||||
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
|
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
|
||||||
|
|
||||||
|
# A1111 BF16
|
||||||
|
original_functional_layer_norm = torch.nn.functional.layer_norm
|
||||||
|
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||||
|
if weight is not None and input.dtype != weight.data.dtype:
|
||||||
|
input = input.to(dtype=weight.data.dtype)
|
||||||
|
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
||||||
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
|
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
original_functional_linear = torch.nn.functional.linear
|
||||||
|
def functional_linear(input, weight, bias=None):
|
||||||
|
if input.dtype != weight.data.dtype:
|
||||||
|
input = input.to(dtype=weight.data.dtype)
|
||||||
|
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||||
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
|
return original_functional_linear(input, weight, bias=bias)
|
||||||
|
|
||||||
|
original_functional_conv2d = torch.nn.functional.conv2d
|
||||||
|
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||||
|
if input.dtype != weight.data.dtype:
|
||||||
|
input = input.to(dtype=weight.data.dtype)
|
||||||
|
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||||
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
|
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||||
|
|
||||||
|
# A1111 Embedding BF16
|
||||||
|
original_torch_cat = torch.cat
|
||||||
|
def torch_cat(tensor, *args, **kwargs):
|
||||||
|
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||||
|
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_cat(tensor, *args, **kwargs)
|
||||||
|
|
||||||
|
# SwinIR BF16:
|
||||||
|
original_functional_pad = torch.nn.functional.pad
|
||||||
|
def functional_pad(input, pad, mode='constant', value=None):
|
||||||
|
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
||||||
|
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
||||||
|
else:
|
||||||
|
return original_functional_pad(input, pad, mode=mode, value=value)
|
||||||
|
|
||||||
|
|
||||||
|
original_torch_tensor = torch.tensor
|
||||||
|
def torch_tensor(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_tensor(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_Tensor_to = torch.Tensor.to
|
||||||
|
def Tensor_to(self, device=None, *args, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_Tensor_to(self, device, *args, **kwargs)
|
||||||
|
|
||||||
|
original_Tensor_cuda = torch.Tensor.cuda
|
||||||
|
def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_Tensor_cuda(self, device, *args, **kwargs)
|
||||||
|
|
||||||
|
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
||||||
|
def UntypedStorage_init(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
||||||
|
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
||||||
|
|
||||||
|
original_torch_empty = torch.empty
|
||||||
|
def torch_empty(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_empty(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_randn = torch.randn
|
||||||
|
def torch_randn(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_randn(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_ones = torch.ones
|
||||||
|
def torch_ones(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_ones(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_zeros = torch.zeros
|
||||||
|
def torch_zeros(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_zeros(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_linspace = torch.linspace
|
||||||
|
def torch_linspace(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_linspace(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_Generator = torch.Generator
|
||||||
|
def torch_Generator(device=None):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_Generator(return_xpu(device))
|
||||||
|
else:
|
||||||
|
return original_torch_Generator(device)
|
||||||
|
|
||||||
|
original_torch_load = torch.load
|
||||||
|
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
|
||||||
|
if check_device(map_location):
|
||||||
|
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||||
|
|
||||||
|
# Hijack Functions:
|
||||||
def ipex_hijacks():
|
def ipex_hijacks():
|
||||||
CondFunc('torch.tensor',
|
torch.tensor = torch_tensor
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
torch.Tensor.to = Tensor_to
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
torch.Tensor.cuda = Tensor_cuda
|
||||||
CondFunc('torch.Tensor.to',
|
torch.UntypedStorage.__init__ = UntypedStorage_init
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
torch.empty = torch_empty
|
||||||
CondFunc('torch.Tensor.cuda',
|
torch.randn = torch_randn
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
torch.ones = torch_ones
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
torch.zeros = torch_zeros
|
||||||
CondFunc('torch.UntypedStorage.__init__',
|
torch.linspace = torch_linspace
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
torch.Generator = torch_Generator
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
torch.load = torch_load
|
||||||
CondFunc('torch.UntypedStorage.cuda',
|
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
|
||||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
|
||||||
CondFunc('torch.empty',
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
|
||||||
CondFunc('torch.randn',
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
|
||||||
CondFunc('torch.ones',
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
|
||||||
CondFunc('torch.zeros',
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
|
||||||
CondFunc('torch.linspace',
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
|
||||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
|
||||||
CondFunc('torch.load',
|
|
||||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
|
|
||||||
orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
|
|
||||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
|
|
||||||
|
|
||||||
CondFunc('torch.Generator',
|
|
||||||
lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
|
|
||||||
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
|
|
||||||
|
|
||||||
# TiledVAE and ControlNet:
|
|
||||||
CondFunc('torch.batch_norm',
|
|
||||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
|
||||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
|
||||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
|
||||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
|
||||||
CondFunc('torch.instance_norm',
|
|
||||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
|
||||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
|
||||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
|
||||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
|
||||||
|
|
||||||
# Functions with dtype errors:
|
|
||||||
CondFunc('torch.nn.modules.GroupNorm.forward',
|
|
||||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
|
||||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
|
||||||
# Training:
|
|
||||||
CondFunc('torch.nn.modules.linear.Linear.forward',
|
|
||||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
|
||||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
|
||||||
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
|
||||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
|
||||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
|
||||||
# BF16:
|
|
||||||
CondFunc('torch.nn.functional.layer_norm',
|
|
||||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
|
||||||
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
|
|
||||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
|
||||||
weight is not None and input.dtype != weight.data.dtype)
|
|
||||||
# SwinIR BF16:
|
|
||||||
CondFunc('torch.nn.functional.pad',
|
|
||||||
lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
|
|
||||||
lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
|
|
||||||
|
|
||||||
# Diffusers Float64 (Alchemist GPUs don't support 64 bit):
|
|
||||||
if not torch.xpu.has_fp64_dtype():
|
|
||||||
CondFunc('torch.from_numpy',
|
|
||||||
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
|
|
||||||
lambda orig_func, ndarray: ndarray.dtype == float)
|
|
||||||
|
|
||||||
# Broken functions when torch.cuda.is_available is True:
|
|
||||||
# Pin Memory:
|
|
||||||
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
|
|
||||||
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
|
|
||||||
lambda orig_func, *args, **kwargs: True)
|
|
||||||
|
|
||||||
# Functions that make compile mad with CondFunc:
|
|
||||||
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
|
|
||||||
torch.nn.DataParallel = DummyDataParallel
|
|
||||||
torch.autocast = ipex_autocast
|
|
||||||
torch.cat = torch_cat
|
|
||||||
torch.linalg.solve = linalg_solve
|
|
||||||
torch.UntypedStorage.is_cuda = is_cuda
|
|
||||||
torch.nn.functional.interpolate = interpolate
|
|
||||||
torch.backends.cuda.sdp_kernel = return_null_context
|
torch.backends.cuda.sdp_kernel = return_null_context
|
||||||
|
torch.nn.DataParallel = DummyDataParallel
|
||||||
|
torch.UntypedStorage.is_cuda = is_cuda
|
||||||
|
torch.autocast = ipex_autocast
|
||||||
|
|
||||||
|
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||||
|
torch.nn.functional.group_norm = functional_group_norm
|
||||||
|
torch.nn.functional.layer_norm = functional_layer_norm
|
||||||
|
torch.nn.functional.linear = functional_linear
|
||||||
|
torch.nn.functional.conv2d = functional_conv2d
|
||||||
|
torch.nn.functional.interpolate = interpolate
|
||||||
|
torch.nn.functional.pad = functional_pad
|
||||||
|
|
||||||
|
torch.bmm = torch_bmm
|
||||||
|
torch.cat = torch_cat
|
||||||
|
if not torch.xpu.has_fp64_dtype():
|
||||||
|
torch.from_numpy = from_numpy
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import numpy as np
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
import diffusers
|
import diffusers
|
||||||
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
||||||
@@ -520,6 +520,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
safety_checker: StableDiffusionSafetyChecker,
|
safety_checker: StableDiffusionSafetyChecker,
|
||||||
feature_extractor: CLIPFeatureExtractor,
|
feature_extractor: CLIPFeatureExtractor,
|
||||||
requires_safety_checker: bool = True,
|
requires_safety_checker: bool = True,
|
||||||
|
image_encoder: CLIPVisionModelWithProjection = None,
|
||||||
clip_skip: int = 1,
|
clip_skip: int = 1,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -531,32 +532,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
safety_checker=safety_checker,
|
safety_checker=safety_checker,
|
||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
requires_safety_checker=requires_safety_checker,
|
requires_safety_checker=requires_safety_checker,
|
||||||
|
image_encoder=image_encoder,
|
||||||
)
|
)
|
||||||
self.clip_skip = clip_skip
|
self.custom_clip_skip = clip_skip
|
||||||
self.__init__additional__()
|
self.__init__additional__()
|
||||||
|
|
||||||
# else:
|
|
||||||
# def __init__(
|
|
||||||
# self,
|
|
||||||
# vae: AutoencoderKL,
|
|
||||||
# text_encoder: CLIPTextModel,
|
|
||||||
# tokenizer: CLIPTokenizer,
|
|
||||||
# unet: UNet2DConditionModel,
|
|
||||||
# scheduler: SchedulerMixin,
|
|
||||||
# safety_checker: StableDiffusionSafetyChecker,
|
|
||||||
# feature_extractor: CLIPFeatureExtractor,
|
|
||||||
# ):
|
|
||||||
# super().__init__(
|
|
||||||
# vae=vae,
|
|
||||||
# text_encoder=text_encoder,
|
|
||||||
# tokenizer=tokenizer,
|
|
||||||
# unet=unet,
|
|
||||||
# scheduler=scheduler,
|
|
||||||
# safety_checker=safety_checker,
|
|
||||||
# feature_extractor=feature_extractor,
|
|
||||||
# )
|
|
||||||
# self.__init__additional__()
|
|
||||||
|
|
||||||
def __init__additional__(self):
|
def __init__additional__(self):
|
||||||
if not hasattr(self, "vae_scale_factor"):
|
if not hasattr(self, "vae_scale_factor"):
|
||||||
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
||||||
@@ -624,7 +604,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||||
max_embeddings_multiples=max_embeddings_multiples,
|
max_embeddings_multiples=max_embeddings_multiples,
|
||||||
clip_skip=self.clip_skip,
|
clip_skip=self.custom_clip_skip,
|
||||||
)
|
)
|
||||||
bs_embed, seq_len, _ = text_embeddings.shape
|
bs_embed, seq_len, _ = text_embeddings.shape
|
||||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||||
|
|||||||
@@ -4,10 +4,13 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
if torch.xpu.is_available():
|
||||||
from library.ipex import ipex_init
|
from library.ipex import ipex_init
|
||||||
|
|
||||||
ipex_init()
|
ipex_init()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -571,9 +574,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
|
|||||||
if key.startswith("cond_stage_model.transformer"):
|
if key.startswith("cond_stage_model.transformer"):
|
||||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||||
|
|
||||||
# support checkpoint without position_ids (invalid checkpoint)
|
# remove position_ids for newer transformer, which causes error :(
|
||||||
if "text_model.embeddings.position_ids" not in text_model_dict:
|
if "text_model.embeddings.position_ids" in text_model_dict:
|
||||||
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
text_model_dict.pop("text_model.embeddings.position_ids")
|
||||||
|
|
||||||
return text_model_dict
|
return text_model_dict
|
||||||
|
|
||||||
@@ -1307,19 +1310,19 @@ def load_vae(vae_id, dtype):
|
|||||||
|
|
||||||
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
||||||
max_width, max_height = max_reso
|
max_width, max_height = max_reso
|
||||||
max_area = (max_width // divisible) * (max_height // divisible)
|
max_area = max_width * max_height
|
||||||
|
|
||||||
resos = set()
|
resos = set()
|
||||||
|
|
||||||
size = int(math.sqrt(max_area)) * divisible
|
width = int(math.sqrt(max_area) // divisible) * divisible
|
||||||
resos.add((size, size))
|
resos.add((width, width))
|
||||||
|
|
||||||
size = min_size
|
width = min_size
|
||||||
while size <= max_size:
|
while width <= max_size:
|
||||||
width = size
|
height = min(max_size, int((max_area // width) // divisible) * divisible)
|
||||||
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
if height >= min_size:
|
||||||
resos.add((width, height))
|
resos.add((width, height))
|
||||||
resos.add((height, width))
|
resos.add((height, width))
|
||||||
|
|
||||||
# # make additional resos
|
# # make additional resos
|
||||||
# if width >= height and width - divisible >= min_size:
|
# if width >= height and width - divisible >= min_size:
|
||||||
@@ -1329,7 +1332,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
|
|||||||
# resos.add((width, height - divisible))
|
# resos.add((width, height - divisible))
|
||||||
# resos.add((height - divisible, width))
|
# resos.add((height - divisible, width))
|
||||||
|
|
||||||
size += divisible
|
width += divisible
|
||||||
|
|
||||||
resos = list(resos)
|
resos = list(resos)
|
||||||
resos.sort()
|
resos.sort()
|
||||||
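The new side of the `make_bucket_resolutions` hunks above can be read as the following standalone sketch. The indentation and the final `return` are assumptions, since the flattened diff shows neither; everything else follows the added lines. The point of the change is that `max_area` is now a pixel area rather than a count of `divisible`-sized cells, and candidates whose height falls below `min_size` are skipped.

```python
import math


def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
    """Sketch of the post-change logic; indentation and return value are assumed."""
    max_width, max_height = max_reso
    max_area = max_width * max_height  # area in pixels, not in `divisible` cells

    resos = set()

    # largest square bucket whose side is a multiple of `divisible`
    width = int(math.sqrt(max_area) // divisible) * divisible
    resos.add((width, width))

    width = min_size
    while width <= max_size:
        # tallest height that keeps width * height within max_area
        height = min(max_size, int((max_area // width) // divisible) * divisible)
        if height >= min_size:  # skip buckets smaller than min_size
            resos.add((width, height))
            resos.add((height, width))
        width += divisible

    resos = list(resos)
    resos.sort()
    return resos
```

Under these assumptions, calling `make_bucket_resolutions((512, 512))` yields only buckets whose sides are multiples of 64, lie between 256 and 1024, and whose area does not exceed 512 * 512.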
|
|||||||
@@ -100,7 +100,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
key = key.replace(".ln_final", ".final_layer_norm")
|
key = key.replace(".ln_final", ".final_layer_norm")
|
||||||
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
|
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
|
||||||
elif ".embeddings.position_ids" in key:
|
elif ".embeddings.position_ids" in key:
|
||||||
key = None # remove this key: make position_ids by ourselves
|
key = None # remove this key: position_ids is not used in newer transformers
|
||||||
return key
|
return key
|
||||||
|
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
@@ -126,10 +126,6 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
||||||
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
||||||
|
|
||||||
# original SD does not have position_ids, so add them
|
|
||||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
|
||||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
|
||||||
|
|
||||||
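# logit_scale is not included in Diffusers, but we return it separately so it can be restored when saving
|
# logit_scale is not included in Diffusers, but we return it separately so it can be restored when saving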
|
||||||
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||||
|
|
||||||
@@ -265,9 +261,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
|||||||
elif k.startswith("conditioner.embedders.1.model."):
|
elif k.startswith("conditioner.embedders.1.model."):
|
||||||
te2_sd[k] = state_dict.pop(k)
|
te2_sd[k] = state_dict.pop(k)
|
||||||
|
|
||||||
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
|
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
|
||||||
if "text_model.embeddings.position_ids" not in te1_sd:
|
if "text_model.embeddings.position_ids" in te1_sd:
|
||||||
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
|
te1_sd.pop("text_model.embeddings.position_ids")
|
||||||
|
|
||||||
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
||||||
print("text encoder 1:", info1)
|
print("text encoder 1:", info1)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from accelerate import Accelerator, InitProcessGroupKwargs
|
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
||||||
import gc
|
import gc
|
||||||
import glob
|
import glob
|
||||||
import math
|
import math
|
||||||
@@ -351,6 +351,7 @@ class BaseSubset:
|
|||||||
shuffle_caption: bool,
|
shuffle_caption: bool,
|
||||||
caption_separator: str,
|
caption_separator: str,
|
||||||
keep_tokens: int,
|
keep_tokens: int,
|
||||||
|
keep_tokens_separator: str,
|
||||||
color_aug: bool,
|
color_aug: bool,
|
||||||
flip_aug: bool,
|
flip_aug: bool,
|
||||||
face_crop_aug_range: Optional[Tuple[float, float]],
|
face_crop_aug_range: Optional[Tuple[float, float]],
|
||||||
@@ -368,6 +369,7 @@ class BaseSubset:
|
|||||||
self.shuffle_caption = shuffle_caption
|
self.shuffle_caption = shuffle_caption
|
||||||
self.caption_separator = caption_separator
|
self.caption_separator = caption_separator
|
||||||
self.keep_tokens = keep_tokens
|
self.keep_tokens = keep_tokens
|
||||||
|
self.keep_tokens_separator = keep_tokens_separator
|
||||||
self.color_aug = color_aug
|
self.color_aug = color_aug
|
||||||
self.flip_aug = flip_aug
|
self.flip_aug = flip_aug
|
||||||
self.face_crop_aug_range = face_crop_aug_range
|
self.face_crop_aug_range = face_crop_aug_range
|
||||||
@@ -395,6 +397,7 @@ class DreamBoothSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator: str,
|
caption_separator: str,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -415,6 +418,7 @@ class DreamBoothSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -449,6 +453,7 @@ class FineTuningSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -469,6 +474,7 @@ class FineTuningSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -500,6 +506,7 @@ class ControlNetSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -520,6 +527,7 @@ class ControlNetSubset(BaseSubset):
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
caption_separator,
|
caption_separator,
|
||||||
keep_tokens,
|
keep_tokens,
|
||||||
|
keep_tokens_separator,
|
||||||
color_aug,
|
color_aug,
|
||||||
flip_aug,
|
flip_aug,
|
||||||
face_crop_aug_range,
|
face_crop_aug_range,
|
||||||
@@ -654,15 +662,33 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
caption = ""
|
caption = ""
|
||||||
else:
|
else:
|
||||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||||
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
|
fixed_tokens = []
|
||||||
|
flex_tokens = []
|
||||||
|
if (
|
||||||
|
hasattr(subset, "keep_tokens_separator")
|
||||||
|
and subset.keep_tokens_separator
|
||||||
|
and subset.keep_tokens_separator in caption
|
||||||
|
):
|
||||||
|
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
|
||||||
|
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
||||||
|
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
||||||
|
else:
|
||||||
|
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
|
||||||
|
flex_tokens = tokens[:]
|
||||||
|
if subset.keep_tokens > 0:
|
||||||
|
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||||
|
flex_tokens = tokens[subset.keep_tokens :]
|
||||||
|
|
||||||
if subset.token_warmup_step < 1: # overwrite on first use
|
if subset.token_warmup_step < 1: # overwrite on first use
|
||||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||||
tokens_len = (
|
tokens_len = (
|
||||||
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
|
math.floor(
|
||||||
|
(self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))
|
||||||
|
)
|
||||||
+ subset.token_warmup_min
|
+ subset.token_warmup_min
|
||||||
)
|
)
|
||||||
tokens = tokens[:tokens_len]
|
flex_tokens = flex_tokens[:tokens_len]
|
||||||
|
|
||||||
def dropout_tags(tokens):
|
def dropout_tags(tokens):
|
||||||
if subset.caption_tag_dropout_rate <= 0:
|
if subset.caption_tag_dropout_rate <= 0:
|
||||||
@@ -673,12 +699,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
l.append(token)
|
l.append(token)
|
||||||
return l
|
return l
|
||||||
|
|
||||||
fixed_tokens = []
|
|
||||||
flex_tokens = tokens[:]
|
|
||||||
if subset.keep_tokens > 0:
|
|
||||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
|
||||||
flex_tokens = tokens[subset.keep_tokens :]
|
|
||||||
|
|
||||||
if subset.shuffle_caption:
|
if subset.shuffle_caption:
|
||||||
random.shuffle(flex_tokens)
|
random.shuffle(flex_tokens)
|
||||||
|
|
||||||
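The fixed/flexible token split introduced in the `BaseDataset` hunk above can be sketched in isolation as follows. The helper names `split_caption` and `build_caption`, their parameters, and the `|||` separator in the usage note are hypothetical and exist only for illustration; the branching mirrors the added lines, where `keep_tokens_separator` takes precedence and `keep_tokens` is the fallback.

```python
import random


def split_caption(caption, caption_separator=",", keep_tokens_separator="", keep_tokens=0):
    """Hypothetical helper: return (fixed_tokens, flex_tokens) as in the diff's logic."""
    fixed_tokens = []
    flex_tokens = []
    if keep_tokens_separator and keep_tokens_separator in caption:
        # everything before the first separator is fixed, the rest is flexible
        fixed_part, flex_part = caption.split(keep_tokens_separator, 1)
        fixed_tokens = [t.strip() for t in fixed_part.split(caption_separator) if t.strip()]
        flex_tokens = [t.strip() for t in flex_part.split(caption_separator) if t.strip()]
    else:
        # fall back to keeping the first `keep_tokens` tokens fixed
        tokens = [t.strip() for t in caption.strip().split(caption_separator)]
        flex_tokens = tokens[:]
        if keep_tokens > 0:
            fixed_tokens = flex_tokens[:keep_tokens]
            flex_tokens = tokens[keep_tokens:]
    return fixed_tokens, flex_tokens


def build_caption(caption, shuffle=False, rng=None, **kwargs):
    """Hypothetical helper: shuffle only the flexible part, then rejoin."""
    fixed_tokens, flex_tokens = split_caption(caption, **kwargs)
    if shuffle:
        (rng or random).shuffle(flex_tokens)
    return ", ".join(fixed_tokens + flex_tokens)
```

For example, `split_caption("1girl, solo ||| red hair, smile", keep_tokens_separator="|||")` keeps `1girl` and `solo` fixed and leaves `red hair` and `smile` open to shuffling, warmup truncation, and tag dropout.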
@@ -1724,6 +1744,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
subset.shuffle_caption,
|
subset.shuffle_caption,
|
||||||
subset.caption_separator,
|
subset.caption_separator,
|
||||||
subset.keep_tokens,
|
subset.keep_tokens,
|
||||||
|
subset.keep_tokens_separator,
|
||||||
subset.color_aug,
|
subset.color_aug,
|
||||||
subset.flip_aug,
|
subset.flip_aug,
|
||||||
subset.face_crop_aug_range,
|
subset.face_crop_aug_range,
|
||||||
@@ -2827,6 +2848,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
|
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamo_backend",
|
||||||
|
type=str,
|
||||||
|
default="inductor",
|
||||||
|
# available backends:
|
||||||
|
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
|
||||||
|
# https://pytorch.org/docs/stable/torch.compiler.html
|
||||||
|
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||||
|
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)"
|
||||||
|
)
|
||||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sdpa",
|
"--sdpa",
|
||||||
@@ -2878,6 +2910,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
default=None,
|
default=None,
|
||||||
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
|
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddp_gradient_as_bucket_view",
|
||||||
|
action="store_true",
|
||||||
|
help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddp_static_graph",
|
||||||
|
action="store_true",
|
||||||
|
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--clip_skip",
|
"--clip_skip",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -3131,6 +3173,13 @@ def add_dataset_arguments(
|
|||||||
default=0,
|
default=0,
|
||||||
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
|
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--keep_tokens_separator",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
|
||||||
|
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--caption_prefix",
|
"--caption_prefix",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -3832,15 +3881,25 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
if args.wandb_api_key is not None:
|
if args.wandb_api_key is not None:
|
||||||
wandb.login(key=args.wandb_api_key)
|
wandb.login(key=args.wandb_api_key)
|
||||||
|
|
||||||
|
# torch.compile option; when it is "NO", torch.compile is not used
|
||||||
|
dynamo_backend = "NO"
|
||||||
|
if args.torch_compile:
|
||||||
|
dynamo_backend = args.dynamo_backend
|
||||||
|
|
||||||
kwargs_handlers = (
|
kwargs_handlers = (
|
||||||
None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
|
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
|
||||||
|
DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph)
|
||||||
|
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
|
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
mixed_precision=args.mixed_precision,
|
mixed_precision=args.mixed_precision,
|
||||||
log_with=log_with,
|
log_with=log_with,
|
||||||
project_dir=logging_dir,
|
project_dir=logging_dir,
|
||||||
kwargs_handlers=kwargs_handlers,
|
kwargs_handlers=kwargs_handlers,
|
||||||
|
dynamo_backend=dynamo_backend,
|
||||||
)
|
)
|
||||||
return accelerator
|
return accelerator
|
||||||
|
|
||||||
@@ -4561,7 +4620,7 @@ def line_to_prompt_dict(line: str) -> dict:
|
|||||||
|
|
||||||
def sample_images_common(
|
def sample_images_common(
|
||||||
pipe_class,
|
pipe_class,
|
||||||
accelerator,
|
accelerator: Accelerator,
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
epoch,
|
epoch,
|
||||||
steps,
|
steps,
|
||||||
@@ -4598,6 +4657,13 @@ def sample_images_common(
|
|||||||
org_vae_device = vae.device # should be on CPU
|
org_vae_device = vae.device # should be on CPU
|
||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
|
# unwrap unet and text_encoder(s)
|
||||||
|
unet = accelerator.unwrap_model(unet)
|
||||||
|
if isinstance(text_encoder, (list, tuple)):
|
||||||
|
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
|
||||||
|
else:
|
||||||
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
# read prompts
|
# read prompts
|
||||||
|
|
||||||
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
accelerate==0.23.0
|
accelerate==0.25.0
|
||||||
transformers==4.30.2
|
transformers==4.36.2
|
||||||
diffusers[torch]==0.21.2
|
diffusers[torch]==0.25.0
|
||||||
ftfy==6.1.1
|
ftfy==6.1.1
|
||||||
# albumentations==1.3.0
|
# albumentations==1.3.0
|
||||||
opencv-python==4.7.0.68
|
opencv-python==4.7.0.68
|
||||||
einops==0.6.0
|
einops==0.6.1
|
||||||
pytorch-lightning==1.9.0
|
pytorch-lightning==1.9.0
|
||||||
# bitsandbytes==0.39.1
|
# bitsandbytes==0.39.1
|
||||||
tensorboard==2.10.1
|
tensorboard==2.10.1
|
||||||
@@ -14,7 +14,7 @@ altair==4.2.2
|
|||||||
easygui==0.98.3
|
easygui==0.98.3
|
||||||
toml==0.10.2
|
toml==0.10.2
|
||||||
voluptuous==0.13.1
|
voluptuous==0.13.1
|
||||||
huggingface-hub==0.15.1
|
huggingface-hub==0.20.1
|
||||||
# for BLIP captioning
|
# for BLIP captioning
|
||||||
# requests==2.28.2
|
# requests==2.28.2
|
||||||
# timm==0.6.12
|
# timm==0.6.12
|
||||||
|
|||||||
@@ -398,6 +398,9 @@ def train(args):
|
|||||||
if train_unet:
|
if train_unet:
|
||||||
unet = accelerator.prepare(unet)
|
unet = accelerator.prepare(unet)
|
||||||
if train_text_encoder1:
|
if train_text_encoder1:
|
||||||
|
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
||||||
|
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||||
|
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||||
if train_text_encoder2:
|
if train_text_encoder2:
|
||||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||||
@@ -484,7 +487,7 @@ def train(args):
|
|||||||
# if NaN is found, show a warning and replace it with zero
|
# if NaN is found, show a warning and replace it with zero
|
||||||
if torch.any(torch.isnan(latents)):
|
if torch.any(torch.isnan(latents)):
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
|
|
||||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||||
|
|||||||
@@ -394,7 +394,7 @@ def train(args):
|
|||||||
# if NaN is found, show a warning and replace it with zero
|
# if NaN is found, show a warning and replace it with zero
|
||||||
if torch.any(torch.isnan(latents)):
|
if torch.any(torch.isnan(latents)):
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
|
|
||||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ def train(args):
|
|||||||
# if NaN is found, show a warning and replace it with zero
|
# if NaN is found, show a warning and replace it with zero
|
||||||
if torch.any(torch.isnan(latents)):
|
if torch.any(torch.isnan(latents)):
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
|
|
||||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ def convert(args):
|
|||||||
is_load_ckpt = os.path.isfile(args.model_to_load)
|
is_load_ckpt = os.path.isfile(args.model_to_load)
|
||||||
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
||||||
|
|
||||||
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||||
# assert (
|
# assert (
|
||||||
# is_save_ckpt or args.reference_model is not None
|
# is_save_ckpt or args.reference_model is not None
|
||||||
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||||
@@ -34,10 +34,12 @@ def convert(args):
|
|||||||
|
|
||||||
if is_load_ckpt:
|
if is_load_ckpt:
|
||||||
v2_model = args.v2
|
v2_model = args.v2
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
|
||||||
|
v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
|
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
|
||||||
)
|
)
|
||||||
text_encoder = pipe.text_encoder
|
text_encoder = pipe.text_encoder
|
||||||
vae = pipe.vae
|
vae = pipe.vae
|
||||||
@@ -57,15 +59,26 @@ def convert(args):
|
|||||||
if is_save_ckpt:
|
if is_save_ckpt:
|
||||||
original_model = args.model_to_load if is_load_ckpt else None
|
original_model = args.model_to_load if is_load_ckpt else None
|
||||||
key_count = model_util.save_stable_diffusion_checkpoint(
|
key_count = model_util.save_stable_diffusion_checkpoint(
|
||||||
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
|
v2_model,
|
||||||
|
args.model_to_save,
|
||||||
|
text_encoder,
|
||||||
|
unet,
|
||||||
|
original_model,
|
||||||
|
args.epoch,
|
||||||
|
args.global_step,
|
||||||
|
None if args.metadata is None else eval(args.metadata),
|
||||||
|
save_dtype=save_dtype,
|
||||||
|
vae=vae,
|
||||||
)
|
)
|
||||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||||
else:
|
else:
|
||||||
print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}")
|
print(
|
||||||
|
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
|
||||||
|
)
|
||||||
model_util.save_diffusers_checkpoint(
|
model_util.save_diffusers_checkpoint(
|
||||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||||
)
|
)
|
||||||
print(f"model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
@@ -77,7 +90,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
|
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)"
|
"--unet_use_linear_projection",
|
||||||
|
action="store_true",
|
||||||
|
help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--fp16",
|
"--fp16",
|
||||||
@@ -99,6 +114,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
|
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metadata",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--variant",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--reference_model",
|
"--reference_model",
|
||||||
type=str,
|
type=str,
|
||||||
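Note on the `--metadata` hunk above: the new code passes the command-line string to `eval`, which runs any expression the caller supplies. A minimal sketch of a stricter parse follows, assuming the value is always a plain dict literal as in the help text's example. The helper name `parse_metadata` is hypothetical and is not part of this commit or of `model_util`.

```python
import argparse
import ast
from typing import Optional


def parse_metadata(raw: Optional[str]) -> Optional[dict]:
    """Hypothetical helper: turn a --metadata string into a dict.

    ast.literal_eval only accepts Python literals (dicts, lists, strings,
    numbers, ...), so function calls and imports are rejected.
    """
    if raw is None:
        return None
    value = ast.literal_eval(raw)
    if not isinstance(value, dict):
        raise ValueError("metadata must be a Python dict literal")
    return value


# Same argument definition as in the diff, reduced to the relevant option.
parser = argparse.ArgumentParser()
parser.add_argument("--metadata", type=str, default=None)

args = parser.parse_args(["--metadata", '{"name": "model_name", "resolution": "512x512"}'])
metadata = parse_metadata(args.metadata)
```

With this sketch the call site would pass `parse_metadata(args.metadata)` instead of the `eval` expression. Whether the stricter parse suits every existing caller is an open question, not something the commit states.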
|
|||||||
@@ -750,7 +750,7 @@ class NetworkTrainer:
|
|||||||
# NaNが含まれていれば警告を表示し0に置き換える
|
# NaNが含まれていれば警告を表示し0に置き換える
|
||||||
if torch.any(torch.isnan(latents)):
|
if torch.any(torch.isnan(latents)):
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
latents = latents * self.vae_scale_factor
|
latents = latents * self.vae_scale_factor
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
|
|||||||
@@ -441,9 +441,10 @@ class TextualInversionTrainer:
|
|||||||
|
|
||||||
# Freeze all parameters except for the token embeddings in text encoder
|
# Freeze all parameters except for the token embeddings in text encoder
|
||||||
text_encoder.requires_grad_(True)
|
text_encoder.requires_grad_(True)
|
||||||
text_encoder.text_model.encoder.requires_grad_(False)
|
unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
unwrapped_text_encoder.text_model.encoder.requires_grad_(False)
|
||||||
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||||
|
unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||||
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
|
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
|
||||||
|
|
||||||
unet.requires_grad_(False)
|
unet.requires_grad_(False)
|
||||||
@@ -603,7 +604,7 @@ class TextualInversionTrainer:
|
|||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters()
|
||||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
@@ -615,9 +616,11 @@ class TextualInversionTrainer:
|
|||||||
for text_encoder, orig_embeds_params, index_no_updates in zip(
|
for text_encoder, orig_embeds_params, index_no_updates in zip(
|
||||||
text_encoders, orig_embeds_params_list, index_no_updates_list
|
text_encoders, orig_embeds_params_list, index_no_updates_list
|
||||||
):
|
):
|
||||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
# if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
|
||||||
|
input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
|
||||||
|
input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
|
||||||
index_no_updates
|
index_no_updates
|
||||||
] = orig_embeds_params[index_no_updates]
|
]
|
||||||
|
|
||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
|
|||||||