Compare commits

..

9 Commits

Author SHA1 Message Date
Kohya S.
308a0cc9fc Merge pull request #2312 from kohya-ss/dev
Merge dev to main
2026-04-07 08:53:13 +09:00
Kohya S
7e60e163c1 Merge branch 'main' into dev 2026-04-07 08:49:58 +09:00
Kohya S.
a8f5c222e0 Merge pull request #2311 from kohya-ss/doc-update-readme-for-next-release
README: Add planned changes for next release (intel GPU compatibility)
2026-04-07 08:47:37 +09:00
Kohya S
1d588d6cb6 README: Add planned changes for next release and improve Intel GPU compatibility 2026-04-07 08:44:31 +09:00
Kohya S.
a7d35701a0 Merge pull request #2307 from WhitePr/dev
update ipex
2026-04-07 08:41:41 +09:00
WhitePr
8da05a10dc Update IPEX libs 2026-04-04 05:37:18 +09:00
WhitePr
197b129284 Modifying the method for get the Torch version 2026-04-04 04:44:53 +09:00
Kohya S.
51435f1718 Merge pull request #2303 from kohya-ss/sd3
fix: improve numerical stability by conditionally using float32 in Anima with fp16 training
2026-04-02 12:40:48 +09:00
Kohya S.
fa53f71ec0 fix: improve numerical stability by conditionally using float32 in Anima (#2302)
* fix: improve numerical stability by conditionally using float32 in block computations

* doc: update README for improvement stability for fp16 training on Anima in version 0.10.3
2026-04-02 12:36:29 +09:00
6 changed files with 41 additions and 165 deletions

View File

@@ -50,6 +50,12 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴 ### 更新履歴
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):** - **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。 - SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。 - 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。

View File

@@ -47,6 +47,12 @@ If you find this project helpful, please consider supporting its development via
### Change History ### Change History
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):** - **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294). - LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details. - Please refer to the [documentation](./docs/train_leco.md) for details.

View File

@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
x_B_T_H_W_D: torch.Tensor, x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor, emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
use_fp32: bool = False,
): ):
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers) # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32): with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora: if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor, emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor, crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams, attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32: if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context. # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float() x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,6 +959,7 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor, emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor, crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams, attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -972,6 +973,7 @@ class Block(nn.Module):
emb_B_T_D, emb_B_T_D,
crossattn_emb, crossattn_emb,
attn_params, attn_params,
use_fp32,
rope_emb_L_1_1_D, rope_emb_L_1_1_D,
adaln_lora_B_T_3D, adaln_lora_B_T_3D,
extra_per_block_pos_emb, extra_per_block_pos_emb,
@@ -994,6 +996,7 @@ class Block(nn.Module):
emb_B_T_D, emb_B_T_D,
crossattn_emb, crossattn_emb,
attn_params, attn_params,
use_fp32,
rope_emb_L_1_1_D, rope_emb_L_1_1_D,
adaln_lora_B_T_3D, adaln_lora_B_T_3D,
extra_per_block_pos_emb, extra_per_block_pos_emb,
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
emb_B_T_D, emb_B_T_D,
crossattn_emb, crossattn_emb,
attn_params, attn_params,
use_fp32,
rope_emb_L_1_1_D, rope_emb_L_1_1_D,
adaln_lora_B_T_3D, adaln_lora_B_T_3D,
extra_per_block_pos_emb, extra_per_block_pos_emb,
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
emb_B_T_D, emb_B_T_D,
crossattn_emb, crossattn_emb,
attn_params, attn_params,
use_fp32,
rope_emb_L_1_1_D, rope_emb_L_1_1_D,
adaln_lora_B_T_3D, adaln_lora_B_T_3D,
extra_per_block_pos_emb, extra_per_block_pos_emb,
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn) attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
for block_idx, block in enumerate(self.blocks): for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap: if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx) self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs) x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
if self.blocks_to_swap: if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx) self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp return x_B_C_Tt_Hp_Wp

View File

@@ -1,7 +1,5 @@
import json import json
import os import os
from pathlib import Path
import re
from dataclasses import replace from dataclasses import replace
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -28,63 +26,6 @@ MODEL_NAME_SCHNELL = "schnell"
MODEL_VERSION_CHROMA = "chroma" MODEL_VERSION_CHROMA = "chroma"
def get_checkpoint_paths(ckpt_path: str | Path):
"""
Get checkpoint paths for flux models
- huggingface directory structure
- huggingface sharded safetensors files
- in transformer directory
- plain directory
- single safetensor files
"""
if not isinstance(ckpt_path, Path):
# Convert to Path object
ckpt_path = Path(ckpt_path)
# If ckpt_path is a directory
if ckpt_path.is_dir():
# List to store potential checkpoint paths
potential_paths = []
# Check for files directly in the directory
potential_paths.extend(ckpt_path.glob('*.safetensors'))
# Check for files in the transformer subdirectory
transformer_path = ckpt_path / 'transformer'
if transformer_path.is_dir():
potential_paths.extend(transformer_path.glob('*.safetensors'))
# Filter and expand multi-part checkpoint paths
checkpoint_paths = []
for path in potential_paths:
# If it's a multi-part checkpoint
if '-of-' in path.name:
# Use regex to extract parts
match = re.search(r'(.+?)-(\d+)-of-(\d+)', path.name)
if match:
base_name, current_part, total_parts = match.groups()
# Generate all part paths
part_paths = [
path.with_name(f'{base_name}-{i:05d}-of-{int(total_parts):05d}.safetensors')
for i in range(1, int(total_parts) + 1)
]
checkpoint_paths.extend(part_paths)
else:
# Single file checkpoint
checkpoint_paths.append(path)
# Remove duplicates while preserving order
checkpoint_paths = list(dict.fromkeys(checkpoint_paths))
else:
# If ckpt_path is a single file
checkpoint_paths = [ckpt_path]
return checkpoint_paths
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
""" """
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
@@ -102,7 +43,12 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
ckpt_paths = get_checkpoint_paths(ckpt_path) if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
if "00001-of-00003" in ckpt_path:
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
else:
ckpt_paths = [ckpt_path]
keys = [] keys = []
for ckpt_path in ckpt_paths: for ckpt_path in ckpt_paths:

View File

@@ -1,6 +1,7 @@
import os import os
import sys import sys
import torch import torch
from packaging import version
try: try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True has_ipex = True
@@ -8,7 +9,7 @@ except Exception:
has_ipex = False has_ipex = False
from .hijacks import ipex_hijacks from .hijacks import ipex_hijacks
torch_version = float(torch.__version__[:3]) torch_version = version.parse(torch.__version__)
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -56,7 +57,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__path__ = torch.xpu.__path__ torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__ torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__ torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__ torch.cuda.__builtins__ = torch.xpu.__builtins__
@@ -64,14 +64,12 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.StreamContext = torch.xpu.StreamContext torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__ torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__ torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__ torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if torch_version < 2.3: if torch_version < version.parse("2.3"):
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
@@ -114,17 +112,22 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback torch.cuda.traceback = torch.xpu.traceback
if torch_version < 2.5: if torch_version < version.parse("2.5"):
torch.cuda.os = torch.xpu.os torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if torch_version < 2.7: if torch_version < version.parse("2.7"):
torch.cuda.Tuple = torch.xpu.Tuple torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List torch.cuda.List = torch.xpu.List
if torch_version < version.parse("2.11"):
torch.cuda._device_t = torch.xpu._device_t
torch.cuda._device = torch.xpu._device
torch.cuda.Union = torch.xpu.Union
# Memory: # Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
@@ -160,7 +163,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed torch.cuda.initial_seed = torch.xpu.initial_seed
# C # C
if torch_version < 2.3: if torch_version < version.parse("2.3"):
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12 ipex._C._DeviceProperties.major = 12

View File

@@ -1,93 +0,0 @@
import pytest
from pathlib import Path
import tempfile
from library.flux_utils import get_checkpoint_paths
def test_get_checkpoint_paths():
# Create a temporary directory for testing
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Scenario 1: Single safetensors file in root directory
single_file = temp_path / "model.safetensors"
single_file.touch()
paths = get_checkpoint_paths(str(single_file))
assert len(paths) == 1
assert paths[0] == single_file
def test_multiple_root_checkpoint_paths():
"""
Multiple single safetensors files in root directory
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Scenario 2:
file1 = temp_path / "model1.safetensors"
file2 = temp_path / "model2.safetensors"
file1.touch()
file2.touch()
paths = get_checkpoint_paths(temp_path)
assert len(paths) == 2
assert set(paths) == {file1, file2}
def test_multipart_sharded_checkpoint():
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Scenario 3: Sharded multi-part checkpoint
# Create sharded checkpoint files
base_name = "diffusion_pytorch_model"
total_parts = 3
for i in range(1, total_parts + 1):
(temp_path / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors").touch()
paths = get_checkpoint_paths(temp_path)
assert len(paths) == total_parts
# Check if all expected part paths are present
expected_paths = [temp_path / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors" for i in range(1, total_parts + 1)]
assert set(paths) == set(expected_paths)
def test_transformer_model_dir():
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
transformer_dir = temp_path / "transformer"
transformer_dir.mkdir()
transformer_file = transformer_dir / "diffusion_pytorch_model.safetensors"
transformer_file.touch()
paths = get_checkpoint_paths(temp_path)
assert transformer_file in paths
def test_mixed_files_sharded_checkpoints():
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Scenario 5: Mixed files and sharded checkpoints
mixed_dir = temp_path / "mixed"
mixed_dir.mkdir()
# Create a single file
(mixed_dir / "single_model.safetensors").touch()
# Create sharded checkpoint
base_name = "diffusion_pytorch_model"
total_parts = 2
for i in range(1, total_parts + 1):
(mixed_dir / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors").touch()
paths = get_checkpoint_paths(mixed_dir)
assert len(paths) == total_parts + 1
# Verify correct handling of Path and str inputs
path_input = mixed_dir
str_input = str(mixed_dir)
path_paths = get_checkpoint_paths(path_input)
str_paths = get_checkpoint_paths(str_input)
assert set(path_paths) == set(str_paths)