mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Compare commits
6 Commits
dev
...
e721d21b17
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e721d21b17 | ||
|
|
3bbfa9b258 | ||
|
|
3420a6f7d1 | ||
|
|
45cab086cc | ||
|
|
f25cb8abd1 | ||
|
|
c15e6b4f3b |
@@ -50,9 +50,6 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
|
||||
|
||||
### 更新履歴
|
||||
|
||||
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
|
||||
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
|
||||
|
||||
|
||||
@@ -47,9 +47,6 @@ If you find this project helpful, please consider supporting its development via
|
||||
|
||||
### Change History
|
||||
|
||||
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
|
||||
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from packaging import version
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
has_ipex = True
|
||||
@@ -9,7 +8,7 @@ except Exception:
|
||||
has_ipex = False
|
||||
from .hijacks import ipex_hijacks
|
||||
|
||||
torch_version = version.parse(torch.__version__)
|
||||
torch_version = float(torch.__version__[:3])
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
@@ -57,6 +56,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.__path__ = torch.xpu.__path__
|
||||
torch.cuda.set_stream = torch.xpu.set_stream
|
||||
torch.cuda.torch = torch.xpu.torch
|
||||
torch.cuda.Union = torch.xpu.Union
|
||||
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
||||
torch.cuda.__package__ = torch.xpu.__package__
|
||||
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
||||
@@ -64,12 +64,14 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.StreamContext = torch.xpu.StreamContext
|
||||
torch.cuda._lazy_call = torch.xpu._lazy_call
|
||||
torch.cuda.random = torch.xpu.random
|
||||
torch.cuda._device = torch.xpu._device
|
||||
torch.cuda.__name__ = torch.xpu.__name__
|
||||
torch.cuda._device_t = torch.xpu._device_t
|
||||
torch.cuda.__spec__ = torch.xpu.__spec__
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
|
||||
if torch_version < version.parse("2.3"):
|
||||
if torch_version < 2.3:
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
@@ -112,22 +114,17 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.threading = torch.xpu.threading
|
||||
torch.cuda.traceback = torch.xpu.traceback
|
||||
|
||||
if torch_version < version.parse("2.5"):
|
||||
if torch_version < 2.5:
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
|
||||
if torch_version < version.parse("2.7"):
|
||||
if torch_version < 2.7:
|
||||
torch.cuda.Tuple = torch.xpu.Tuple
|
||||
torch.cuda.List = torch.xpu.List
|
||||
|
||||
if torch_version < version.parse("2.11"):
|
||||
torch.cuda._device_t = torch.xpu._device_t
|
||||
torch.cuda._device = torch.xpu._device
|
||||
torch.cuda.Union = torch.xpu.Union
|
||||
|
||||
|
||||
# Memory:
|
||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||
@@ -163,7 +160,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
|
||||
# C
|
||||
if torch_version < version.parse("2.3"):
|
||||
if torch_version < 2.3:
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||
ipex._C._DeviceProperties.major = 12
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -31,6 +32,10 @@ import torch.nn.functional as F
|
||||
|
||||
from library import custom_offloading_utils
|
||||
|
||||
disable_selective_torch_compile = (
|
||||
os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
|
||||
)
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
@@ -549,7 +554,7 @@ class JointAttention(nn.Module):
|
||||
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
|
||||
)
|
||||
|
||||
|
||||
@torch.compiler.disable(reason="complex ops inside")
|
||||
def apply_rope(
|
||||
x_in: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
@@ -625,13 +630,10 @@ class FeedForward(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
nn.init.xavier_uniform_(self.w3.weight)
|
||||
|
||||
# @torch.compile
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
|
||||
@torch.compile(disable=disable_selective_torch_compile)
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
return self.w2(F.silu(self.w1(x))*self.w3(x))
|
||||
|
||||
|
||||
class JointTransformerBlock(GradientCheckpointMixin):
|
||||
@@ -697,6 +699,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
|
||||
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||
|
||||
@torch.compile(disable=disable_selective_torch_compile)
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -788,6 +791,7 @@ class FinalLayer(GradientCheckpointMixin):
|
||||
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
||||
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
||||
|
||||
@torch.compile(disable=disable_selective_torch_compile)
|
||||
def forward(self, x, c):
|
||||
scale = self.adaLN_modulation(c)
|
||||
x = modulate(self.norm_final(x), scale)
|
||||
@@ -808,6 +812,7 @@ class RopeEmbedder:
|
||||
self.axes_lens = axes_lens
|
||||
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
|
||||
@torch.compiler.disable(reason="complex ops inside")
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
device = ids.device
|
||||
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
|
||||
@@ -1219,6 +1224,7 @@ class NextDiT(nn.Module):
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.compiler.disable(reason="complex ops inside")
|
||||
def precompute_freqs_cis(
|
||||
dim: List[int],
|
||||
end: List[int],
|
||||
|
||||
@@ -3992,6 +3992,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--activation_memory_budget",
|
||||
type=float,
|
||||
default=None,
|
||||
help="activation memory budget setting for torch.compile (range: 0~1). Smaller value saves more memory at cost of speed. If set, use --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4. / torch.compileのactivation memory budget設定(0~1の値)。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。"
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument(
|
||||
"--sdpa",
|
||||
@@ -5539,6 +5545,19 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
if args.torch_compile:
|
||||
dynamo_backend = args.dynamo_backend
|
||||
|
||||
if args.activation_memory_budget is not None: # Note: 0 is a valid value.
|
||||
if 0 <= args.activation_memory_budget <= 1:
|
||||
logger.info(
|
||||
f"set torch compile activation memory budget to {args.activation_memory_budget}"
|
||||
)
|
||||
torch._functorch.config.activation_memory_budget = ( # type: ignore
|
||||
args.activation_memory_budget
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"activation_memory_budget must be between 0 and 1 (inclusive)"
|
||||
)
|
||||
|
||||
kwargs_handlers = [
|
||||
(
|
||||
InitProcessGroupKwargs(
|
||||
|
||||
Reference in New Issue
Block a user