Compare commits

..

8 Commits

Author SHA1 Message Date
Kohya S.
308a0cc9fc Merge pull request #2312 from kohya-ss/dev
Merge dev to main
2026-04-07 08:53:13 +09:00
Kohya S
7e60e163c1 Merge branch 'main' into dev 2026-04-07 08:49:58 +09:00
Kohya S.
a8f5c222e0 Merge pull request #2311 from kohya-ss/doc-update-readme-for-next-release
README: Add planned changes for next release (intel GPU compatibility)
2026-04-07 08:47:37 +09:00
Kohya S
1d588d6cb6 README: Add planned changes for next release and improve Intel GPU compatibility 2026-04-07 08:44:31 +09:00
Kohya S.
a7d35701a0 Merge pull request #2307 from WhitePr/dev
update ipex
2026-04-07 08:41:41 +09:00
WhitePr
8da05a10dc Update IPEX libs 2026-04-04 05:37:18 +09:00
WhitePr
197b129284 Modifying the method for getting the Torch version 2026-04-04 04:44:53 +09:00
Kohya S.
51435f1718 Merge pull request #2303 from kohya-ss/sd3
fix: improve numerical stability by conditionally using float32 in Anima with fp16 training
2026-04-02 12:40:48 +09:00
5 changed files with 23 additions and 161 deletions

View File

@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴 ### 更新履歴
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
- **Version 0.10.3 (2026-04-02):** - **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。 - Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。

View File

@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
### Change History ### Change History
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
- **Version 0.10.3 (2026-04-02):** - **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue. - Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.

View File

@@ -1,7 +1,5 @@
import json import json
import os import os
from pathlib import Path
import re
from dataclasses import replace from dataclasses import replace
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -28,63 +26,6 @@ MODEL_NAME_SCHNELL = "schnell"
MODEL_VERSION_CHROMA = "chroma" MODEL_VERSION_CHROMA = "chroma"
def get_checkpoint_paths(ckpt_path: str | Path):
"""
Get checkpoint paths for flux models
- huggingface directory structure
- huggingface sharded safetensors files
- in transformer directory
- plain directory
- single safetensor files
"""
if not isinstance(ckpt_path, Path):
# Convert to Path object
ckpt_path = Path(ckpt_path)
# If ckpt_path is a directory
if ckpt_path.is_dir():
# List to store potential checkpoint paths
potential_paths = []
# Check for files directly in the directory
potential_paths.extend(ckpt_path.glob('*.safetensors'))
# Check for files in the transformer subdirectory
transformer_path = ckpt_path / 'transformer'
if transformer_path.is_dir():
potential_paths.extend(transformer_path.glob('*.safetensors'))
# Filter and expand multi-part checkpoint paths
checkpoint_paths = []
for path in potential_paths:
# If it's a multi-part checkpoint
if '-of-' in path.name:
# Use regex to extract parts
match = re.search(r'(.+?)-(\d+)-of-(\d+)', path.name)
if match:
base_name, current_part, total_parts = match.groups()
# Generate all part paths
part_paths = [
path.with_name(f'{base_name}-{i:05d}-of-{int(total_parts):05d}.safetensors')
for i in range(1, int(total_parts) + 1)
]
checkpoint_paths.extend(part_paths)
else:
# Single file checkpoint
checkpoint_paths.append(path)
# Remove duplicates while preserving order
checkpoint_paths = list(dict.fromkeys(checkpoint_paths))
else:
# If ckpt_path is a single file
checkpoint_paths = [ckpt_path]
return checkpoint_paths
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
""" """
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
@@ -102,7 +43,12 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
ckpt_paths = get_checkpoint_paths(ckpt_path) if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
if "00001-of-00003" in ckpt_path:
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
else:
ckpt_paths = [ckpt_path]
keys = [] keys = []
for ckpt_path in ckpt_paths: for ckpt_path in ckpt_paths:

View File

@@ -1,6 +1,7 @@
import os import os
import sys import sys
import torch import torch
from packaging import version
try: try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True has_ipex = True
@@ -8,7 +9,7 @@ except Exception:
has_ipex = False has_ipex = False
from .hijacks import ipex_hijacks from .hijacks import ipex_hijacks
torch_version = float(torch.__version__[:3]) torch_version = version.parse(torch.__version__)
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -56,7 +57,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__path__ = torch.xpu.__path__ torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__ torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__ torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__ torch.cuda.__builtins__ = torch.xpu.__builtins__
@@ -64,14 +64,12 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.StreamContext = torch.xpu.StreamContext torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__ torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__ torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__ torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if torch_version < 2.3: if torch_version < version.parse("2.3"):
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
@@ -114,17 +112,22 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback torch.cuda.traceback = torch.xpu.traceback
if torch_version < 2.5: if torch_version < version.parse("2.5"):
torch.cuda.os = torch.xpu.os torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if torch_version < 2.7: if torch_version < version.parse("2.7"):
torch.cuda.Tuple = torch.xpu.Tuple torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List torch.cuda.List = torch.xpu.List
if torch_version < version.parse("2.11"):
torch.cuda._device_t = torch.xpu._device_t
torch.cuda._device = torch.xpu._device
torch.cuda.Union = torch.xpu.Union
# Memory: # Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
@@ -160,7 +163,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed torch.cuda.initial_seed = torch.xpu.initial_seed
# C # C
if torch_version < 2.3: if torch_version < version.parse("2.3"):
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12 ipex._C._DeviceProperties.major = 12

View File

@@ -1,93 +0,0 @@
import pytest
from pathlib import Path
import tempfile
from library.flux_utils import get_checkpoint_paths
def test_get_checkpoint_paths():
    """Scenario 1: a single safetensors file passed as a string path."""
    with tempfile.TemporaryDirectory() as workdir:
        model_file = Path(workdir) / "model.safetensors"
        model_file.touch()
        result = get_checkpoint_paths(str(model_file))
        assert len(result) == 1
        assert result[0] == model_file
def test_multiple_root_checkpoint_paths():
    """Scenario 2: several independent safetensors files in the root directory."""
    with tempfile.TemporaryDirectory() as workdir:
        root = Path(workdir)
        expected = {root / "model1.safetensors", root / "model2.safetensors"}
        for checkpoint in expected:
            checkpoint.touch()
        result = get_checkpoint_paths(root)
        assert len(result) == 2
        assert set(result) == expected
def test_multipart_sharded_checkpoint():
    """Scenario 3: a sharded multi-part checkpoint resolves to every part."""
    with tempfile.TemporaryDirectory() as workdir:
        root = Path(workdir)
        stem = "diffusion_pytorch_model"
        parts = 3
        shard_paths = [
            root / f"{stem}-{idx:05d}-of-{parts:05d}.safetensors"
            for idx in range(1, parts + 1)
        ]
        for shard in shard_paths:
            shard.touch()
        result = get_checkpoint_paths(root)
        assert len(result) == parts
        assert set(result) == set(shard_paths)
def test_transformer_model_dir():
    """Scenario 4: files inside a 'transformer' subdirectory are discovered."""
    with tempfile.TemporaryDirectory() as workdir:
        root = Path(workdir)
        subdir = root / "transformer"
        subdir.mkdir()
        weights = subdir / "diffusion_pytorch_model.safetensors"
        weights.touch()
        assert weights in get_checkpoint_paths(root)
def test_mixed_files_sharded_checkpoints():
    """Scenario 5: a directory mixing a standalone file with a sharded
    checkpoint, plus equivalence of Path vs. str inputs."""
    with tempfile.TemporaryDirectory() as workdir:
        mixed_dir = Path(workdir) / "mixed"
        mixed_dir.mkdir()
        # One standalone checkpoint file.
        (mixed_dir / "single_model.safetensors").touch()
        # One two-part sharded checkpoint.
        stem = "diffusion_pytorch_model"
        parts = 2
        for idx in range(1, parts + 1):
            (mixed_dir / f"{stem}-{idx:05d}-of-{parts:05d}.safetensors").touch()
        result = get_checkpoint_paths(mixed_dir)
        assert len(result) == parts + 1
        # Path and str inputs must resolve to the same set of files.
        from_path = get_checkpoint_paths(mixed_dir)
        from_str = get_checkpoint_paths(str(mixed_dir))
        assert set(from_path) == set(from_str)