Compare commits

..

3 Commits

Author SHA1 Message Date
Dave Lage
ca7a65d915 Merge 24ab4c0c4a into 1dae34b0af 2026-04-01 13:03:39 +00:00
rockerBOO
24ab4c0c4a Support loading more checkpoint types 2025-03-26 16:35:04 -04:00
rockerBOO
c0f2808763 Support more checkpoint files for flux 2025-03-26 16:34:47 -04:00
6 changed files with 165 additions and 41 deletions

View File

@@ -50,12 +50,6 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。

View File

@@ -47,12 +47,6 @@ If you find this project helpful, please consider supporting its development via
### Change History
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details.

View File

@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
use_fp32: bool = False,
):
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,7 +959,6 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -973,7 +972,6 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -996,7 +994,6 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1010,7 +1007,6 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1022,7 +1018,6 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1343,19 +1338,16 @@ class Anima(nn.Module):
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@@ -1,5 +1,7 @@
import json
import os
from pathlib import Path
import re
from dataclasses import replace
from typing import List, Optional, Tuple, Union
@@ -26,6 +28,63 @@ MODEL_NAME_SCHNELL = "schnell"
MODEL_VERSION_CHROMA = "chroma"
def get_checkpoint_paths(ckpt_path: str | Path):
"""
Get checkpoint paths for flux models
- huggingface directory structure
- huggingface sharded safetensors files
- in transformer directory
- plain directory
- single safetensor files
"""
if not isinstance(ckpt_path, Path):
# Convert to Path object
ckpt_path = Path(ckpt_path)
# If ckpt_path is a directory
if ckpt_path.is_dir():
# List to store potential checkpoint paths
potential_paths = []
# Check for files directly in the directory
potential_paths.extend(ckpt_path.glob('*.safetensors'))
# Check for files in the transformer subdirectory
transformer_path = ckpt_path / 'transformer'
if transformer_path.is_dir():
potential_paths.extend(transformer_path.glob('*.safetensors'))
# Filter and expand multi-part checkpoint paths
checkpoint_paths = []
for path in potential_paths:
# If it's a multi-part checkpoint
if '-of-' in path.name:
# Use regex to extract parts
match = re.search(r'(.+?)-(\d+)-of-(\d+)', path.name)
if match:
base_name, current_part, total_parts = match.groups()
# Generate all part paths
part_paths = [
path.with_name(f'{base_name}-{i:05d}-of-{int(total_parts):05d}.safetensors')
for i in range(1, int(total_parts) + 1)
]
checkpoint_paths.extend(part_paths)
else:
# Single file checkpoint
checkpoint_paths.append(path)
# Remove duplicates while preserving order
checkpoint_paths = list(dict.fromkeys(checkpoint_paths))
else:
# If ckpt_path is a single file
checkpoint_paths = [ckpt_path]
return checkpoint_paths
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
"""
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
@@ -43,12 +102,7 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
if "00001-of-00003" in ckpt_path:
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
else:
ckpt_paths = [ckpt_path]
ckpt_paths = get_checkpoint_paths(ckpt_path)
keys = []
for ckpt_path in ckpt_paths:

View File

@@ -1,7 +1,6 @@
import os
import sys
import torch
from packaging import version
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True
@@ -9,7 +8,7 @@ except Exception:
has_ipex = False
from .hijacks import ipex_hijacks
torch_version = version.parse(torch.__version__)
torch_version = float(torch.__version__[:3])
# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -57,6 +56,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
@@ -64,12 +64,14 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if torch_version < version.parse("2.3"):
if torch_version < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
@@ -112,22 +114,17 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback
if torch_version < version.parse("2.5"):
if torch_version < 2.5:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if torch_version < version.parse("2.7"):
if torch_version < 2.7:
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List
if torch_version < version.parse("2.11"):
torch.cuda._device_t = torch.xpu._device_t
torch.cuda._device = torch.xpu._device
torch.cuda.Union = torch.xpu.Union
# Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
@@ -163,7 +160,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed
# C
if torch_version < version.parse("2.3"):
if torch_version < 2.3:
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12

View File

@@ -0,0 +1,93 @@
import pytest
from pathlib import Path
import tempfile
from library.flux_utils import get_checkpoint_paths
def test_get_checkpoint_paths():
# Create a temporary directory for testing
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Scenario 1: Single safetensors file in root directory
single_file = temp_path / "model.safetensors"
single_file.touch()
paths = get_checkpoint_paths(str(single_file))
assert len(paths) == 1
assert paths[0] == single_file
def test_multiple_root_checkpoint_paths():
"""
Multiple single safetensors files in root directory
"""
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Scenario 2:
file1 = temp_path / "model1.safetensors"
file2 = temp_path / "model2.safetensors"
file1.touch()
file2.touch()
paths = get_checkpoint_paths(temp_path)
assert len(paths) == 2
assert set(paths) == {file1, file2}
def test_multipart_sharded_checkpoint():
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Scenario 3: Sharded multi-part checkpoint
# Create sharded checkpoint files
base_name = "diffusion_pytorch_model"
total_parts = 3
for i in range(1, total_parts + 1):
(temp_path / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors").touch()
paths = get_checkpoint_paths(temp_path)
assert len(paths) == total_parts
# Check if all expected part paths are present
expected_paths = [temp_path / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors" for i in range(1, total_parts + 1)]
assert set(paths) == set(expected_paths)
def test_transformer_model_dir():
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
transformer_dir = temp_path / "transformer"
transformer_dir.mkdir()
transformer_file = transformer_dir / "diffusion_pytorch_model.safetensors"
transformer_file.touch()
paths = get_checkpoint_paths(temp_path)
assert transformer_file in paths
def test_mixed_files_sharded_checkpoints():
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Scenario 5: Mixed files and sharded checkpoints
mixed_dir = temp_path / "mixed"
mixed_dir.mkdir()
# Create a single file
(mixed_dir / "single_model.safetensors").touch()
# Create sharded checkpoint
base_name = "diffusion_pytorch_model"
total_parts = 2
for i in range(1, total_parts + 1):
(mixed_dir / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors").touch()
paths = get_checkpoint_paths(mixed_dir)
assert len(paths) == total_parts + 1
# Verify correct handling of Path and str inputs
path_input = mixed_dir
str_input = str(mixed_dir)
path_paths = get_checkpoint_paths(path_input)
str_paths = get_checkpoint_paths(str_input)
assert set(path_paths) == set(str_paths)