Compare commits

..

7 Commits

Author SHA1 Message Date
Kohya S.
308a0cc9fc Merge pull request #2312 from kohya-ss/dev
Merge dev to main
2026-04-07 08:53:13 +09:00
Kohya S
7e60e163c1 Merge branch 'main' into dev 2026-04-07 08:49:58 +09:00
Kohya S.
a8f5c222e0 Merge pull request #2311 from kohya-ss/doc-update-readme-for-next-release
README: Add planned changes for next release (Intel GPU compatibility)
2026-04-07 08:47:37 +09:00
Kohya S
1d588d6cb6 README: Add planned changes for next release and improve Intel GPU compatibility 2026-04-07 08:44:31 +09:00
Kohya S.
a7d35701a0 Merge pull request #2307 from WhitePr/dev
update ipex
2026-04-07 08:41:41 +09:00
WhitePr
8da05a10dc Update IPEX libs 2026-04-04 05:37:18 +09:00
WhitePr
197b129284 Modifying the method to get the Torch version 2026-04-04 04:44:53 +09:00
4 changed files with 83 additions and 44 deletions

View File

@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴 ### 更新履歴
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
- **Version 0.10.3 (2026-04-02):** - **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。 - Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。

View File

@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
### Change History ### Change History
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
- **Version 0.10.3 (2026-04-02):** - **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue. - Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.

View File

@@ -542,20 +542,10 @@ class PipelineLike:
uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048
if do_classifier_free_guidance: if do_classifier_free_guidance:
lcm = uncond_embeddings.shape[1] * text_embeddings.shape[1] // math.gcd(uncond_embeddings.shape[1], text_embeddings.shape[1])
if negative_scale is None: if negative_scale is None:
text_embeddings = torch.cat([ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
uncond_embeddings.repeat(1, lcm // uncond_embeddings.shape[1], 1),
text_embeddings.repeat(1, lcm // text_embeddings.shape[1], 1),
])
else: else:
lcm = real_uncond_embeddings.shape[1] * text_embeddings.shape[1] // math.gcd(real_uncond_embeddings.shape[1], text_embeddings.shape[1]) text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
text_embeddings = torch.cat([
uncond_embeddings.repeat(1, lcm // uncond_embeddings.shape[1], 1),
text_embeddings.repeat(1, lcm // text_embeddings.shape[1], 1),
real_uncond_embeddings.repeat(1, lcm // real_uncond_embeddings.shape[1], 1)
])
if self.control_net_lllites or (self.control_nets and self.is_sdxl): if self.control_net_lllites or (self.control_nets and self.is_sdxl):
# ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う # ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う
@@ -1115,17 +1105,22 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos
""" """
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
lcm = chunk_length
for i in range(len(tokens)): for i in range(len(tokens)):
target_length = ((len(tokens[i]) + 2) // chunk_length + 1) * chunk_length tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
lcm = target_length * lcm // math.gcd(target_length, lcm) if no_boseos_middle:
tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (target_length - 2 - len(tokens[i])) weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
weights[i] = [1.0] + weights[i] + [1.0] * (target_length - 1 - len(weights[i])) else:
w = []
for i in range(len(tokens)): if len(weights[i]) == 0:
tokens[i] = tokens[i] * (lcm // len(tokens[i])) w = [1.0] * weights_length
weights[i] = weights[i] * (lcm // len(weights[i])) else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights return tokens, weights
@@ -1143,21 +1138,56 @@ def get_unweighted_text_embeddings(
When the length of tokens is a multiple of the capacity of the text encoder, When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually. it should be split into chunks and sent to the text encoder individually.
""" """
pool = None max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
text_embeddings = [] if max_embeddings_multiples > 1:
text_embeddings = []
for chunk in text_input.chunk(text_input.shape[1] // chunk_length, dim=1): pool = None
enc_out = text_encoder(chunk, output_hidden_states=True, return_dict=True) for i in range(max_embeddings_multiples):
text_embedding = enc_out["hidden_states"][-clip_skip] # extract the i-th chunk
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
if pad == eos: # v1
text_input_chunk[:, -1] = text_input[0, -1]
else: # v2
for j in range(len(text_input_chunk)):
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
text_input_chunk[j, -1] = eos
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
text_input_chunk[j, 1] = eos
# in sdxl, value of clip_skip is same for Text Encoder 1 and 2
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-clip_skip]
if not is_sdxl: # SD 1.5 requires final_layer_norm
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
if pool is None:
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
if pool is not None:
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-clip_skip]
if not is_sdxl: # SD 1.5 requires final_layer_norm if not is_sdxl: # SD 1.5 requires final_layer_norm
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
text_embeddings.append(text_embedding) pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
if pool is None: if pool is not None:
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos)
if pool is not None:
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos)
text_embeddings = torch.cat(text_embeddings, dim=1)
return text_embeddings, pool return text_embeddings, pool

View File

@@ -1,6 +1,7 @@
import os import os
import sys import sys
import torch import torch
from packaging import version
try: try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True has_ipex = True
@@ -8,7 +9,7 @@ except Exception:
has_ipex = False has_ipex = False
from .hijacks import ipex_hijacks from .hijacks import ipex_hijacks
torch_version = float(torch.__version__[:3]) torch_version = version.parse(torch.__version__)
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -56,7 +57,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__path__ = torch.xpu.__path__ torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__ torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__ torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__ torch.cuda.__builtins__ = torch.xpu.__builtins__
@@ -64,14 +64,12 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.StreamContext = torch.xpu.StreamContext torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__ torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__ torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__ torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if torch_version < 2.3: if torch_version < version.parse("2.3"):
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
@@ -114,17 +112,22 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback torch.cuda.traceback = torch.xpu.traceback
if torch_version < 2.5: if torch_version < version.parse("2.5"):
torch.cuda.os = torch.xpu.os torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if torch_version < 2.7: if torch_version < version.parse("2.7"):
torch.cuda.Tuple = torch.xpu.Tuple torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List torch.cuda.List = torch.xpu.List
if torch_version < version.parse("2.11"):
torch.cuda._device_t = torch.xpu._device_t
torch.cuda._device = torch.xpu._device
torch.cuda.Union = torch.xpu.Union
# Memory: # Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
@@ -160,7 +163,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed torch.cuda.initial_seed = torch.xpu.initial_seed
# C # C
if torch_version < 2.3: if torch_version < version.parse("2.3"):
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12 ipex._C._DeviceProperties.major = 12