diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index b6865dbf..bd4ef334 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.16.15 + uses: crate-ci/typos@v1.16.26 diff --git a/README.md b/README.md index 51183a9a..60e665a7 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,41 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Dec 24, 2023 / 2023/12/24 + +- Fixed `tools/convert_diffusers20_original_sd.py`, which had stopped working. Thanks to Disty0! PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016) + +- `tools/convert_diffusers20_original_sd.py` が動かなくなっていたのが修正されました。Disty0 氏に感謝します。 PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016) + + +### Dec 21, 2023 / 2023/12/21 + +- The issues in multi-GPU training are fixed. Thanks to Isotr0py! PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) and [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000) + - `--ddp_gradient_as_bucket_view` and `--ddp_bucket_view` options are added to `sdxl_train.py`. Please specify these options for multi-GPU training. +- IPEX support is updated. Thanks to Disty0! +- Fixed the bug where the size of the bucket could become less than `min_bucket_reso`. Thanks to Cauldrath! PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008) +- `--sample_at_first` option is added to each training script. This option is useful to generate images at the first step, before training. Thanks to shirayu! PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907) +- `--ss` option is added to the sampling prompt in training. You can specify the scheduler for the sampling like `--ss euler_a`. Thanks to shirayu! PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906) +- `keep_tokens_separator` is added to the dataset config. This option is useful to keep (prevent from shuffling) the tokens in the captions. 
See [#975](https://github.com/kohya-ss/sd-scripts/pull/975) for details. Thanks to Linaqruf! + - You can specify the separator with an option like `--keep_tokens_separator "|||"` or with `keep_tokens_separator: "|||"` in `.toml`. The tokens before `|||` are not shuffled. +- Attention processor hook is added. See [#961](https://github.com/kohya-ss/sd-scripts/pull/961) for details. Thanks to rockerBOO! +- The optimizer `PagedAdamW` is added. Thanks to xzuyn! PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955) +- NaN replacement in SDXL VAE is sped up. Thanks to liubo0902! PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009) +- Fixed the path error in `finetune/make_captions.py`. Thanks to CjangCjengh! PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986) + +- マルチGPUでの学習の不具合を修正しました。Isotr0py 氏に感謝します。 PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) および [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000) + - `sdxl_train.py` に `--ddp_gradient_as_bucket_view` と `--ddp_bucket_view` オプションが追加されました。マルチGPUでの学習時にはこれらのオプションを指定してください。 +- IPEX サポートが更新されました。Disty0 氏に感謝します。 +- Aspect Ratio Bucketing で bucket のサイズが `min_bucket_reso` 未満になる不具合を修正しました。Cauldrath 氏に感謝します。 PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008) +- 各学習スクリプトに `--sample_at_first` オプションが追加されました。学習前に画像を生成することで、学習結果が比較しやすくなります。shirayu 氏に感謝します。 PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907) +- 学習時のプロンプトに `--ss` オプションが追加されました。`--ss euler_a` のようにスケジューラを指定できます。shirayu 氏に感謝します。 PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906) +- データセット設定に `keep_tokens_separator` が追加されました。キャプション内のトークンをどの位置までシャッフルしないかを指定できます。詳細は [#975](https://github.com/kohya-ss/sd-scripts/pull/975) を参照してください。Linaqruf 氏に感謝します。 + - オプションで `--keep_tokens_separator "|||"` のように指定するか、`.toml` で `keep_tokens_separator: "|||"` のように指定します。`|||` の前のトークンはシャッフルされません。 +- Attention processor hook が追加されました。詳細は [#961](https://github.com/kohya-ss/sd-scripts/pull/961) を参照してください。rockerBOO 氏に感謝します。 +- オプティマイザ 
`PagedAdamW` が追加されました。xzuyn 氏に感謝します。 PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955) +- 学習時、SDXL VAE で NaN が発生した時の置き換えが高速化されました。liubo0902 氏に感謝します。 PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009) +- `finetune/make_captions.py` で相対パス指定時のエラーが修正されました。CjangCjengh 氏に感謝します。 PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986) + ### Dec 3, 2023 / 2023/12/3 - `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913) diff --git a/finetune/make_captions.py b/finetune/make_captions.py index b20c4106..074576bc 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -13,7 +13,7 @@ import torch from torchvision import transforms from torchvision.transforms.functional import InterpolationMode sys.path.append(os.path.dirname(__file__)) -from blip.blip import blip_decoder +from blip.blip import blip_decoder, is_url import library.train_util as train_util DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -76,6 +76,8 @@ def main(args): cwd = os.getcwd() print("Current Working Directory is: ", cwd) os.chdir("finetune") + if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights): + args.caption_weights = os.path.join("..", args.caption_weights) print(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) diff --git a/library/config_util.py b/library/config_util.py index ab90fb63..47868f3b 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -53,6 +53,7 @@ class BaseSubsetParams: shuffle_caption: bool = False caption_separator: str = ',', keep_tokens: int = 0 + keep_tokens_separator: str = None, color_aug: bool = False flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None @@ -160,6 +161,7 @@ class ConfigSanitizer: "random_crop": bool, "shuffle_caption": bool, "keep_tokens": 
int, + "keep_tokens_separator": str, "token_warmup_min": int, "token_warmup_step": Any(float,int), "caption_prefix": str, @@ -461,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu num_repeats: {subset.num_repeats} shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} caption_dropout_rate: {subset.caption_dropout_rate} caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 662572c8..33350493 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -140,6 +140,7 @@ def ipex_init(): # pylint: disable=too-many-statements # C torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count ipex._C._DeviceProperties.major = 2023 ipex._C._DeviceProperties.minor = 2 @@ -156,20 +157,9 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.get_device_properties.minor = 7 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 - if hasattr(torch.xpu, 'getDeviceIdListForCard'): - torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard - torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard - else: - torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card - torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card ipex_hijacks() if not torch.xpu.has_fp64_dtype(): - try: - from .attention import attention_init - attention_init() - except Exception: # pylint: disable=broad-exception-caught - pass try: from .diffusers import ipex_diffusers ipex_diffusers() diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 52016466..e98807a8 100644 --- 
a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -1,45 +1,98 @@ +import os import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long -original_torch_bmm = torch.bmm -def torch_bmm(input, mat2, *, out=None): - if input.dtype != mat2.dtype: - mat2 = mat2.to(input.dtype) +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] - block_multiply = input.element_size() - slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +# Find something divisible with the input_tokens +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +# Find slice sizes for SDPA +@cache +def find_sdpa_slice_sizes(query_shape, query_element_size): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size block_size = batch_size_attention * slice_block_size split_slice_size = batch_size_attention - if block_size > 4: + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > sdpa_slice_trigger_rate: do_split = True - #Find something divisible with the input_tokens - 
while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + +# Find slice sizes for BMM +@cache +def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): + batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] + slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention split_2_slice_size = input_tokens - if split_slice_size * slice_block_size > 4: - slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply - do_split_2 = True - #Find something divisible with the input_tokens - while (split_2_slice_size * slice_block_size2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False + split_3_slice_size = mat2_atten_shape + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > 
attention_slice_rate: + slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + +original_torch_bmm = torch.bmm +def torch_bmm_32_bit(input, mat2, *, out=None): + if input.device.type != "xpu": + return original_torch_bmm(input, mat2, out=out) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) + + # Slice BMM if do_split: + batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -48,11 +101,21 @@ def torch_bmm(input, mat2, *, out=None): for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size - hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( - input[start_idx:end_idx, start_idx_2:end_idx_2], - mat2[start_idx:end_idx, start_idx_2:end_idx_2], - out=out - ) + if do_split_3: + for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( + input[start_idx:end_idx, 
start_idx_2:end_idx_2, start_idx_3:end_idx_3], + mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + out=out + ) + else: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, start_idx_2:end_idx_2], + out=out + ) else: hidden_states[start_idx:end_idx] = original_torch_bmm( input[start_idx:end_idx], @@ -64,51 +127,14 @@ def torch_bmm(input, mat2, *, out=None): return hidden_states original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention -def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - if len(query.shape) == 3: - batch_size_attention, query_tokens, shape_four = query.shape - shape_one = 1 - no_shape_one = True - else: - shape_one, batch_size_attention, query_tokens, shape_four = query.shape - no_shape_one = False - - if query.dtype != key.dtype: - key = key.to(dtype=query.dtype) - if query.dtype != value.dtype: - value = value.to(dtype=query.dtype) - - block_multiply = query.element_size() - slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply - block_size = batch_size_attention * slice_block_size - - split_slice_size = batch_size_attention - if block_size > 4: - do_split = True - #Find something divisible with the shape_one - while (split_slice_size * slice_block_size) > 4: - split_slice_size = split_slice_size // 2 - if split_slice_size <= 1: - split_slice_size = 1 - break - else: - do_split = False - - split_2_slice_size = query_tokens - if split_slice_size * slice_block_size > 4: - slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply - do_split_2 = True - #Find something divisible with the batch_size_attention - while (split_2_slice_size * slice_block_size2) > 4: - split_2_slice_size = split_2_slice_size // 2 - if 
split_2_slice_size <= 1: - split_2_slice_size = 1 - break - else: - do_split_2 = False +def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + if query.device.type != "xpu": + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + # Slice SDPA if do_split: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) for i in range(batch_size_attention // split_slice_size): start_idx = i * split_slice_size @@ -117,7 +143,18 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size - if no_shape_one: + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) + else: hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( query[start_idx:end_idx, start_idx_2:end_idx_2], key[start_idx:end_idx, start_idx_2:end_idx_2], 
@@ -125,38 +162,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, dropout_p=dropout_p, is_causal=is_causal ) - else: - hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx, start_idx_2:end_idx_2], - key[:, start_idx:end_idx, start_idx_2:end_idx_2], - value[:, start_idx:end_idx, start_idx_2:end_idx_2], - attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) else: - if no_shape_one: - hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) - else: - hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention( - query[:, start_idx:end_idx], - key[:, start_idx:end_idx], - value[:, start_idx:end_idx], - attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask, - dropout_p=dropout_p, is_causal=is_causal - ) + hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal + ) else: - return original_scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal - ) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) return hidden_states - -def attention_init(): - #ARC GPUs can't allocate more than 4GB to a single block: - torch.bmm = torch_bmm - 
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index c32af507..47b0375a 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,10 +1,62 @@ +import os import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import diffusers #0.24.0 # pylint: disable=import-error from diffusers.models.attention_processor import Attention +from diffusers.utils import USE_PEFT_BACKEND +from functools import cache # pylint: disable=protected-access, missing-function-docstring, line-too-long +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +@cache +def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + if slice_size is not None: + batch_size_attention = slice_size + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if query_device_type != "xpu": + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size 
= split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + class SlicedAttnProcessor: # pylint: disable=too-few-public-methods r""" Processor for implementing sliced attention. @@ -18,7 +70,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + residual = hidden_states input_ndim = hidden_states.ndim @@ -54,49 +108,61 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype ) - #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - block_multiply = query.element_size() - slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply - block_size = query_tokens * slice_block_size - split_2_slice_size = query_tokens - if block_size > 4: - do_split_2 = True - #Find something divisible with the query_tokens - while (split_2_slice_size * slice_block_size) > 4: - split_2_slice_size = split_2_slice_size // 2 - if split_2_slice_size <= 1: - split_2_slice_size = 1 - break - 
else: - do_split_2 = False - - for i in range(batch_size_attention // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size if do_split_2: for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx_2 = i2 * split_2_slice_size end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size - query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] - key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] - attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) - hidden_states[start_idx:end_idx, 
start_idx_2:end_idx_2] = attn_slice + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice else: query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - + del query_slice + del key_slice + del attn_mask_slice attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + #################################################################### hidden_states = attn.batch_to_head_dim(hidden_states) @@ -115,6 +181,130 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods return hidden_states + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None, + temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) + + if do_split: + for i in range(batch_size_attention // split_slice_size): 
+ start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del 
attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + #################################################################### + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + def ipex_diffusers(): #ARC GPUs can't allocate more than 4GB to a single block: diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor + diffusers.models.attention_processor.AttnProcessor = AttnProcessor diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 4a9a3569..b6d246dd 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -1,67 +1,9 @@ import contextlib -import importlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return -class CondFunc: # pylint: disable=missing-class-docstring - def __new__(cls, orig_func, sub_func, cond_func): - self = super(CondFunc, cls).__new__(cls) - if isinstance(orig_func, str): - func_path = orig_func.split('.') - for i in range(len(func_path)-1, -1, -1): - try: - resolved_obj = importlib.import_module('.'.join(func_path[:i])) - break - except ImportError: - pass - for attr_name in func_path[i:-1]: - resolved_obj = getattr(resolved_obj, attr_name) - orig_func = getattr(resolved_obj, func_path[-1]) - 
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) - self.__init__(orig_func, sub_func, cond_func) - return lambda *args, **kwargs: self(*args, **kwargs) - def __init__(self, orig_func, sub_func, cond_func): - self.__orig_func = orig_func - self.__sub_func = sub_func - self.__cond_func = cond_func - def __call__(self, *args, **kwargs): - if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): - return self.__sub_func(self.__orig_func, *args, **kwargs) - else: - return self.__orig_func(*args, **kwargs) - -_utils = torch.utils.data._utils -def _shutdown_workers(self): - if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None: - return - if hasattr(self, "_shutdown") and not self._shutdown: - self._shutdown = True - try: - if hasattr(self, '_pin_memory_thread'): - self._pin_memory_thread_done_event.set() - self._worker_result_queue.put((None, None)) - self._pin_memory_thread.join() - self._worker_result_queue.cancel_join_thread() - self._worker_result_queue.close() - self._workers_done_event.set() - for worker_id in range(len(self._workers)): - if self._persistent_workers or self._workers_status[worker_id]: - self._mark_worker_as_unavailable(worker_id, shutdown=True) - for w in self._workers: # pylint: disable=invalid-name - w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL) - for q in self._index_queues: # pylint: disable=invalid-name - q.cancel_join_thread() - q.close() - finally: - if self._worker_pids_set: - torch.utils.data._utils.signal_handling._remove_worker_pids(id(self)) - self._worker_pids_set = False - for w in self._workers: # pylint: disable=invalid-name - if w.is_alive(): - w.terminate() - class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: 
disable=unused-argument if isinstance(device_ids, list) and len(device_ids) > 1: @@ -71,17 +13,18 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr def return_null_context(*args, **kwargs): # pylint: disable=unused-argument return contextlib.nullcontext() +@property +def is_cuda(self): + return self.device.type == 'xpu' or self.device.type == 'cuda' + def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) def return_xpu(device): return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" -def ipex_no_cuda(orig_func, *args, **kwargs): - torch.cuda.is_available = lambda: False - orig_func(*args, **kwargs) - torch.cuda.is_available = torch.xpu.is_available +# Autocast original_autocast = torch.autocast def ipex_autocast(*args, **kwargs): if len(args) > 0 and args[0] == "cuda": @@ -89,15 +32,7 @@ def ipex_autocast(*args, **kwargs): else: return original_autocast(*args, **kwargs) -# Embedding BF16 -original_torch_cat = torch.cat -def torch_cat(tensor, *args, **kwargs): - if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): - return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) - else: - return original_torch_cat(tensor, *args, **kwargs) - -# Latent antialias: +# Latent Antialias CPU Offload: original_interpolate = torch.nn.functional.interpolate def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: @@ -109,110 +44,205 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn return 
original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) -original_linalg_solve = torch.linalg.solve -def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name - if A.device != torch.device("cpu") or B.device != torch.device("cpu"): - return_device = A.device - return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device) +# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): +original_from_numpy = torch.from_numpy +def from_numpy(ndarray): + if ndarray.dtype == float: + return original_from_numpy(ndarray.astype('float32')) else: - return original_linalg_solve(A, B, *args, **kwargs) + return original_from_numpy(ndarray) -@property -def is_cuda(self): - return self.device.type == 'xpu' +if torch.xpu.has_fp64_dtype(): + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +else: + # 32 bit attention workarounds for Alchemist: + try: + from .attention import torch_bmm_32_bit as original_torch_bmm + from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention + except Exception: # pylint: disable=broad-exception-caught + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention + +# Data Type Errors: +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + return original_torch_bmm(input, mat2, out=out) + +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + if query.dtype != key.dtype: + key = key.to(dtype=query.dtype) + if query.dtype != value.dtype: + value = value.to(dtype=query.dtype) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + +# A1111 FP16 
+original_functional_group_norm = torch.nn.functional.group_norm +def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps) + +# A1111 BF16 +original_functional_layer_norm = torch.nn.functional.layer_norm +def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps) + +# Training +original_functional_linear = torch.nn.functional.linear +def functional_linear(input, weight, bias=None): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_linear(input, weight, bias=bias) + +original_functional_conv2d = torch.nn.functional.conv2d +def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +# A1111 Embedding BF16 +original_torch_cat = torch.cat +def torch_cat(tensor, *args, **kwargs): + if len(tensor) == 3 and 
(tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): + return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) + else: + return original_torch_cat(tensor, *args, **kwargs) + +# SwinIR BF16: +original_functional_pad = torch.nn.functional.pad +def functional_pad(input, pad, mode='constant', value=None): + if mode == 'reflect' and input.dtype == torch.bfloat16: + return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) + else: + return original_functional_pad(input, pad, mode=mode, value=value) + + +original_torch_tensor = torch.tensor +def torch_tensor(*args, device=None, **kwargs): + if check_device(device): + return original_torch_tensor(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_tensor(*args, device=device, **kwargs) + +original_Tensor_to = torch.Tensor.to +def Tensor_to(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_to(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_to(self, device, *args, **kwargs) + +original_Tensor_cuda = torch.Tensor.cuda +def Tensor_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_cuda(self, device, *args, **kwargs) + +original_UntypedStorage_init = torch.UntypedStorage.__init__ +def UntypedStorage_init(*args, device=None, **kwargs): + if check_device(device): + return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) + else: + return original_UntypedStorage_init(*args, device=device, **kwargs) + +original_UntypedStorage_cuda = torch.UntypedStorage.cuda +def UntypedStorage_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) + else: + return 
original_UntypedStorage_cuda(self, device, *args, **kwargs) + +original_torch_empty = torch.empty +def torch_empty(*args, device=None, **kwargs): + if check_device(device): + return original_torch_empty(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_empty(*args, device=device, **kwargs) + +original_torch_randn = torch.randn +def torch_randn(*args, device=None, **kwargs): + if check_device(device): + return original_torch_randn(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_randn(*args, device=device, **kwargs) + +original_torch_ones = torch.ones +def torch_ones(*args, device=None, **kwargs): + if check_device(device): + return original_torch_ones(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_ones(*args, device=device, **kwargs) + +original_torch_zeros = torch.zeros +def torch_zeros(*args, device=None, **kwargs): + if check_device(device): + return original_torch_zeros(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_zeros(*args, device=device, **kwargs) + +original_torch_linspace = torch.linspace +def torch_linspace(*args, device=None, **kwargs): + if check_device(device): + return original_torch_linspace(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_linspace(*args, device=device, **kwargs) + +original_torch_Generator = torch.Generator +def torch_Generator(device=None): + if check_device(device): + return original_torch_Generator(return_xpu(device)) + else: + return original_torch_Generator(device) + +original_torch_load = torch.load +def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): + if check_device(map_location): + return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) + else: + return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, 
weights_only=weights_only, mmap=mmap, **kwargs) + +# Hijack Functions: def ipex_hijacks(): - CondFunc('torch.tensor', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.Tensor.to', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.Tensor.cuda', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.UntypedStorage.__init__', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.UntypedStorage.cuda', - lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), - lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) - CondFunc('torch.empty', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.randn', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.ones', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.zeros', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - 
CondFunc('torch.linspace', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.load', - lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: - orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs), - lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location)) + torch.tensor = torch_tensor + torch.Tensor.to = Tensor_to + torch.Tensor.cuda = Tensor_cuda + torch.UntypedStorage.__init__ = UntypedStorage_init + torch.UntypedStorage.cuda = UntypedStorage_cuda + torch.empty = torch_empty + torch.randn = torch_randn + torch.ones = torch_ones + torch.zeros = torch_zeros + torch.linspace = torch_linspace + torch.Generator = torch_Generator + torch.load = torch_load - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), - lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") - - # TiledVAE and ControlNet: - CondFunc('torch.batch_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - CondFunc('torch.instance_norm', - lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, - weight if weight is not None else torch.ones(input.size()[1], device=input.device), - bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), - lambda orig_func, input, *args, **kwargs: input.device != 
torch.device("cpu")) - - # Functions with dtype errors: - CondFunc('torch.nn.modules.GroupNorm.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - # Training: - CondFunc('torch.nn.modules.linear.Linear.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - CondFunc('torch.nn.modules.conv.Conv2d.forward', - lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), - lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - # BF16: - CondFunc('torch.nn.functional.layer_norm', - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), - lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: - weight is not None and input.dtype != weight.data.dtype) - # SwinIR BF16: - CondFunc('torch.nn.functional.pad', - lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16), - lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16) - - # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): - if not torch.xpu.has_fp64_dtype(): - CondFunc('torch.from_numpy', - lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), - lambda orig_func, ndarray: ndarray.dtype == float) - - # Broken functions when torch.cuda.is_available is True: - # Pin Memory: - CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', - lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), - lambda orig_func, *args, **kwargs: True) - - # Functions that make compile mad with CondFunc: - 
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers - torch.nn.DataParallel = DummyDataParallel - torch.autocast = ipex_autocast - torch.cat = torch_cat - torch.linalg.solve = linalg_solve - torch.UntypedStorage.is_cuda = is_cuda - torch.nn.functional.interpolate = interpolate torch.backends.cuda.sdp_kernel = return_null_context + torch.nn.DataParallel = DummyDataParallel + torch.UntypedStorage.is_cuda = is_cuda + torch.autocast = ipex_autocast + + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention + torch.nn.functional.group_norm = functional_group_norm + torch.nn.functional.layer_norm = functional_layer_norm + torch.nn.functional.linear = functional_linear + torch.nn.functional.conv2d = functional_conv2d + torch.nn.functional.interpolate = interpolate + torch.nn.functional.pad = functional_pad + + torch.bmm = torch_bmm + torch.cat = torch_cat + if not torch.xpu.has_fp64_dtype(): + torch.from_numpy = from_numpy diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 9dce91a7..3963e9b1 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -9,7 +9,7 @@ import numpy as np import PIL.Image import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection import diffusers from diffusers import SchedulerMixin, StableDiffusionPipeline @@ -520,6 +520,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, + image_encoder: CLIPVisionModelWithProjection = None, clip_skip: int = 1, ): super().__init__( @@ -531,32 +532,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): safety_checker=safety_checker, 
feature_extractor=feature_extractor, requires_safety_checker=requires_safety_checker, + image_encoder=image_encoder, ) - self.clip_skip = clip_skip + self.custom_clip_skip = clip_skip self.__init__additional__() - # else: - # def __init__( - # self, - # vae: AutoencoderKL, - # text_encoder: CLIPTextModel, - # tokenizer: CLIPTokenizer, - # unet: UNet2DConditionModel, - # scheduler: SchedulerMixin, - # safety_checker: StableDiffusionSafetyChecker, - # feature_extractor: CLIPFeatureExtractor, - # ): - # super().__init__( - # vae=vae, - # text_encoder=text_encoder, - # tokenizer=tokenizer, - # unet=unet, - # scheduler=scheduler, - # safety_checker=safety_checker, - # feature_extractor=feature_extractor, - # ) - # self.__init__additional__() - def __init__additional__(self): if not hasattr(self, "vae_scale_factor"): setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) @@ -624,7 +604,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): prompt=prompt, uncond_prompt=negative_prompt if do_classifier_free_guidance else None, max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, + clip_skip=self.custom_clip_skip, ) bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) diff --git a/library/model_util.py b/library/model_util.py index 00a3c049..6102d0a1 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -4,10 +4,13 @@ import math import os import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -571,9 +574,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint): if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - # support checkpoint without position_ids (invalid checkpoint) - if "text_model.embeddings.position_ids" not in 
text_model_dict: - text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text + # remove position_ids for newer transformer, which causes error :( + if "text_model.embeddings.position_ids" in text_model_dict: + text_model_dict.pop("text_model.embeddings.position_ids") return text_model_dict @@ -1307,19 +1310,19 @@ def load_vae(vae_id, dtype): def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) + max_area = max_width * max_height resos = set() - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) + width = int(math.sqrt(max_area) // divisible) * divisible + resos.add((width, width)) - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) + width = min_size + while width <= max_size: + height = min(max_size, int((max_area // width) // divisible) * divisible) + if height >= min_size: + resos.add((width, height)) + resos.add((height, width)) # # make additional resos # if width >= height and width - divisible >= min_size: @@ -1329,7 +1332,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64) # resos.add((width, height - divisible)) # resos.add((height - divisible, width)) - size += divisible + width += divisible resos = list(resos) resos.sort() diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index a844927c..08b90c39 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -100,7 +100,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): key = key.replace(".ln_final", ".final_layer_norm") # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids elif ".embeddings.position_ids" in key: - key = None # remove this key: 
make position_ids by ourselves + key = None # remove this key: position_ids is not used in newer transformers return key keys = list(checkpoint.keys()) @@ -126,10 +126,6 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): new_sd[key_pfx + "k_proj" + key_suffix] = values[1] new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # original SD にはないので、position_idsを追加 - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - new_sd["text_model.embeddings.position_ids"] = position_ids - # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None) @@ -265,9 +261,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty elif k.startswith("conditioner.embedders.1.model."): te2_sd[k] = state_dict.pop(k) - # 一部のposition_idsがないモデルへの対応 / add position_ids for some models - if "text_model.embeddings.position_ids" not in te1_sd: - te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) + # 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers + if "text_model.embeddings.position_ids" in te1_sd: + te1_sd.pop("text_model.embeddings.position_ids") info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 print("text encoder 1:", info1) diff --git a/library/train_util.py b/library/train_util.py index d2eb7cb2..e22afe1c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -19,7 +19,7 @@ from typing import ( Tuple, Union, ) -from accelerate import Accelerator, InitProcessGroupKwargs +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs import gc import glob import math @@ -351,6 +351,7 @@ class BaseSubset: shuffle_caption: bool, caption_separator: str, keep_tokens: int, + keep_tokens_separator: str, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], @@ -368,6 +369,7 @@ class BaseSubset: 
self.shuffle_caption = shuffle_caption self.caption_separator = caption_separator self.keep_tokens = keep_tokens + self.keep_tokens_separator = keep_tokens_separator self.color_aug = color_aug self.flip_aug = flip_aug self.face_crop_aug_range = face_crop_aug_range @@ -395,6 +397,7 @@ class DreamBoothSubset(BaseSubset): shuffle_caption, caption_separator: str, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -415,6 +418,7 @@ class DreamBoothSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -449,6 +453,7 @@ class FineTuningSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -469,6 +474,7 @@ class FineTuningSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -500,6 +506,7 @@ class ControlNetSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -520,6 +527,7 @@ class ControlNetSubset(BaseSubset): shuffle_caption, caption_separator, keep_tokens, + keep_tokens_separator, color_aug, flip_aug, face_crop_aug_range, @@ -654,15 +662,33 @@ class BaseDataset(torch.utils.data.Dataset): caption = "" else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] + fixed_tokens = [] + flex_tokens = [] + if ( + hasattr(subset, "keep_tokens_separator") + and subset.keep_tokens_separator + and subset.keep_tokens_separator in caption + ): + fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) + fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] + flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] 
+ else: + tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] + flex_tokens = tokens[:] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[: subset.keep_tokens] + flex_tokens = tokens[subset.keep_tokens :] + if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: tokens_len = ( - math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + math.floor( + (self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)) + ) + subset.token_warmup_min ) - tokens = tokens[:tokens_len] + flex_tokens = flex_tokens[:tokens_len] def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: @@ -673,12 +699,6 @@ class BaseDataset(torch.utils.data.Dataset): l.append(token) return l - fixed_tokens = [] - flex_tokens = tokens[:] - if subset.keep_tokens > 0: - fixed_tokens = flex_tokens[: subset.keep_tokens] - flex_tokens = tokens[subset.keep_tokens :] - if subset.shuffle_caption: random.shuffle(flex_tokens) @@ -1724,6 +1744,7 @@ class ControlNetDataset(BaseDataset): subset.shuffle_caption, subset.caption_separator, subset.keep_tokens, + subset.keep_tokens_separator, subset.color_aug, subset.flip_aug, subset.face_crop_aug_range, @@ -2827,6 +2848,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) + parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") + parser.add_argument( + "--dynamo_backend", + type=str, + default="inductor", + # available backends: + # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 + # 
https://pytorch.org/docs/stable/torch.compiler.html + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)" + ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( "--sdpa", @@ -2878,6 +2910,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)", ) + parser.add_argument( + "--ddp_gradient_as_bucket_view", + action="store_true", + help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする", + ) + parser.add_argument( + "--ddp_static_graph", + action="store_true", + help="enable static_graph for DDP / DDPでstatic_graphを有効にする", + ) parser.add_argument( "--clip_skip", type=int, @@ -3131,6 +3173,13 @@ def add_dataset_arguments( default=0, help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", ) + parser.add_argument( + "--keep_tokens_separator", + type=str, + default="", + help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens." 
+ + " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。", + ) parser.add_argument( "--caption_prefix", type=str, @@ -3831,16 +3880,26 @@ def prepare_accelerator(args: argparse.Namespace): os.environ["WANDB_DIR"] = logging_dir if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) + + # torch.compile のオプション。 NO の場合は torch.compile は使わない + dynamo_backend = "NO" + if args.torch_compile: + dynamo_backend = args.dynamo_backend kwargs_handlers = ( - None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))] + InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, + DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None, ) + kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=log_with, project_dir=logging_dir, kwargs_handlers=kwargs_handlers, + dynamo_backend=dynamo_backend, ) return accelerator @@ -4561,7 +4620,7 @@ def line_to_prompt_dict(line: str) -> dict: def sample_images_common( pipe_class, - accelerator, + accelerator: Accelerator, args: argparse.Namespace, epoch, steps, @@ -4598,6 +4657,13 @@ def sample_images_common( org_vae_device = vae.device # CPUにいるはず vae.to(device) + # unwrap unet and text_encoder(s) + unet = accelerator.unwrap_model(unet) + if isinstance(text_encoder, (list, tuple)): + text_encoder = [accelerator.unwrap_model(te) for te in text_encoder] + else: + text_encoder = accelerator.unwrap_model(text_encoder) + # read prompts # with open(args.sample_prompts, "rt", encoding="utf-8") as f: diff --git a/requirements.txt b/requirements.txt index c27131cd..8517d95a 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ -accelerate==0.23.0 -transformers==4.30.2 -diffusers[torch]==0.21.2 +accelerate==0.25.0 +transformers==4.36.2 +diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.7.0.68 -einops==0.6.0 +einops==0.6.1 pytorch-lightning==1.9.0 # bitsandbytes==0.39.1 tensorboard==2.10.1 @@ -14,7 +14,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.15.1 +huggingface-hub==0.20.1 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 diff --git a/sdxl_train.py b/sdxl_train.py index 501eef65..8983673d 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -398,6 +398,9 @@ def train(args): if train_unet: unet = accelerator.prepare(unet) if train_text_encoder1: + # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer + text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) + text_encoder1.text_model.final_layer_norm.requires_grad_(False) text_encoder1 = accelerator.prepare(text_encoder1) if train_text_encoder2: text_encoder2 = accelerator.prepare(text_encoder2) @@ -484,7 +487,7 @@ def train(args): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index cb97859f..18c6bd05 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -394,7 +394,7 @@ def train(args): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), 
latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 87f30301..6ae5377b 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -363,7 +363,7 @@ def train(args): # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * sdxl_model_util.VAE_SCALE_FACTOR if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: diff --git a/tools/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py index b9365b51..fe30996a 100644 --- a/tools/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -23,7 +23,7 @@ def convert(args): is_load_ckpt = os.path.isfile(args.model_to_load) is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 - assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" + assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" # assert ( # is_save_ckpt or args.reference_model is not None # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" @@ -34,10 +34,12 @@ def convert(args): if is_load_ckpt: v2_model = args.v2 - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection) + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( + v2_model, 
args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection + ) else: pipe = StableDiffusionPipeline.from_pretrained( - args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None + args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant ) text_encoder = pipe.text_encoder vae = pipe.vae @@ -57,15 +59,26 @@ def convert(args): if is_save_ckpt: original_model = args.model_to_load if is_load_ckpt else None key_count = model_util.save_stable_diffusion_checkpoint( - v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae + v2_model, + args.model_to_save, + text_encoder, + unet, + original_model, + args.epoch, + args.global_step, + None if args.metadata is None else eval(args.metadata), + save_dtype=save_dtype, + vae=vae, ) print(f"model saved. total converted state_dict keys: {key_count}") else: - print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}") + print( + f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}" + ) model_util.save_diffusers_checkpoint( v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors ) - print(f"model saved.") + print("model saved.") def setup_parser() -> argparse.ArgumentParser: @@ -77,7 +90,9 @@ def setup_parser() -> argparse.ArgumentParser: "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む" ) parser.add_argument( - "--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)" + "--unet_use_linear_projection", + action="store_true", + help="When saving v2 model 
as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)", ) parser.add_argument( "--fp16", @@ -99,6 +114,18 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値" ) + parser.add_argument( + "--metadata", + type=str, + default=None, + help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'', + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16", + ) parser.add_argument( "--reference_model", type=str, diff --git a/train_network.py b/train_network.py index 378a3390..9cba78da 100644 --- a/train_network.py +++ b/train_network.py @@ -750,7 +750,7 @@ class NetworkTrainer: # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)): accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * self.vae_scale_factor b_size = latents.shape[0] diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 877ac838..545b6ba8 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -441,9 +441,10 @@ class TextualInversionTrainer: # Freeze all parameters except for the token embeddings in text encoder text_encoder.requires_grad_(True) - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + unwrapped_text_encoder = accelerator.unwrap_model(text_encoder) + 
unwrapped_text_encoder.text_model.encoder.requires_grad_(False) + unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False) + unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) unet.requires_grad_(False) @@ -603,7 +604,7 @@ class TextualInversionTrainer: accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = text_encoder.get_input_embeddings().parameters() + params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -615,9 +616,11 @@ class TextualInversionTrainer: for text_encoder, orig_embeds_params, index_no_updates in zip( text_encoders, orig_embeds_params_list, index_no_updates_list ): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + # if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32 + input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight + input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[ index_no_updates - ] = orig_embeds_params[index_no_updates] + ] # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: