From c0f2808763422cd898ab00b257b9a115b295314e Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Wed, 26 Mar 2025 16:34:47 -0400
Subject: [PATCH 1/2] Support more checkpoint files for flux

---
 library/flux_utils.py | 66 +++++++++++++++++++++++++++++++++++++++----
 1 file changed, 60 insertions(+), 6 deletions(-)

diff --git a/library/flux_utils.py b/library/flux_utils.py
index 8be1d63e..21162302 100644
--- a/library/flux_utils.py
+++ b/library/flux_utils.py
@@ -1,5 +1,7 @@
 import json
 import os
+from pathlib import Path
+import re
 from dataclasses import replace
 from typing import List, Optional, Tuple, Union
 
@@ -25,6 +27,63 @@ MODEL_NAME_DEV = "dev"
 MODEL_NAME_SCHNELL = "schnell"
 
 
+def get_checkpoint_paths(ckpt_path: str | Path):
+    """
+    Get checkpoint paths for flux models
+
+    - huggingface directory structure
+    - huggingface sharded safetensors files
+    - in transformer directory
+    - plain directory
+    - single safetensor files
+    """
+    if not isinstance(ckpt_path, Path):
+        # Convert to Path object
+        ckpt_path = Path(ckpt_path)
+
+    # If ckpt_path is a directory
+    if ckpt_path.is_dir():
+        # List to store potential checkpoint paths
+        potential_paths = []
+
+        # Check for files directly in the directory
+        potential_paths.extend(ckpt_path.glob('*.safetensors'))
+
+        # Check for files in the transformer subdirectory
+        transformer_path = ckpt_path / 'transformer'
+        if transformer_path.is_dir():
+            potential_paths.extend(transformer_path.glob('*.safetensors'))
+
+        # Filter and expand multi-part checkpoint paths
+        checkpoint_paths = []
+        for path in potential_paths:
+            # If it's a multi-part checkpoint
+            if '-of-' in path.name:
+                # Use regex to extract parts
+                match = re.search(r'(.+?)-(\d+)-of-(\d+)', path.name)
+                if match:
+                    base_name, current_part, total_parts = match.groups()
+
+                    # Generate all part paths
+                    part_paths = [
+                        path.with_name(f'{base_name}-{i:05d}-of-{int(total_parts):05d}.safetensors')
+                        for i in range(1, int(total_parts) + 1)
+                    ]
+
+                    checkpoint_paths.extend(part_paths)
+            else:
+                # Single file checkpoint
+                checkpoint_paths.append(path)
+
+        # Remove duplicates while preserving order
+        checkpoint_paths = list(dict.fromkeys(checkpoint_paths))
+
+    else:
+        # If ckpt_path is a single file
+        checkpoint_paths = [ckpt_path]
+
+    return checkpoint_paths
+
 def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
     """
     チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
@@ -42,12 +101,7 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
     # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
     logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
 
-    if os.path.isdir(ckpt_path):  # if ckpt_path is a directory, it is Diffusers
-        ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
-    if "00001-of-00003" in ckpt_path:
-        ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
-    else:
-        ckpt_paths = [ckpt_path]
+    ckpt_paths = get_checkpoint_paths(ckpt_path)
 
     keys = []
     for ckpt_path in ckpt_paths:

From 24ab4c0c4aac9a9fbb2e8d542239dfad61901511 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Wed, 26 Mar 2025 16:31:21 -0400
Subject: [PATCH 2/2] Support loading more checkpoint types

---
 tests/library/test_flux_utils.py | 93 ++++++++++++++++++++++++++++++++
 1 file changed, 93 insertions(+)
 create mode 100644 tests/library/test_flux_utils.py

diff --git a/tests/library/test_flux_utils.py b/tests/library/test_flux_utils.py
new file mode 100644
index 00000000..5479de19
--- /dev/null
+++ b/tests/library/test_flux_utils.py
@@ -0,0 +1,93 @@
+import pytest
+from pathlib import Path
+import tempfile
+
+from library.flux_utils import get_checkpoint_paths
+
+
+def test_get_checkpoint_paths():
+    # Create a temporary directory for testing
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+
+        # Scenario 1: Single safetensors file in root directory
+        single_file = temp_path / "model.safetensors"
+        single_file.touch()
+        paths = get_checkpoint_paths(str(single_file))
+        assert len(paths) == 1
+        assert paths[0] == single_file
+
+
+def test_multiple_root_checkpoint_paths():
+    """
+    Multiple single safetensors files in root directory
+    """
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+        # Scenario 2:
+        file1 = temp_path / "model1.safetensors"
+        file2 = temp_path / "model2.safetensors"
+        file1.touch()
+        file2.touch()
+        paths = get_checkpoint_paths(temp_path)
+        assert len(paths) == 2
+        assert set(paths) == {file1, file2}
+
+
+def test_multipart_sharded_checkpoint():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+        # Scenario 3: Sharded multi-part checkpoint
+        # Create sharded checkpoint files
+        base_name = "diffusion_pytorch_model"
+        total_parts = 3
+        for i in range(1, total_parts + 1):
+            (temp_path / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors").touch()
+
+        paths = get_checkpoint_paths(temp_path)
+        assert len(paths) == total_parts
+
+        # Check if all expected part paths are present
+        expected_paths = [temp_path / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors" for i in range(1, total_parts + 1)]
+        assert set(paths) == set(expected_paths)
+
+
+def test_transformer_model_dir():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+        transformer_dir = temp_path / "transformer"
+        transformer_dir.mkdir()
+        transformer_file = transformer_dir / "diffusion_pytorch_model.safetensors"
+        transformer_file.touch()
+
+        paths = get_checkpoint_paths(temp_path)
+        assert transformer_file in paths
+
+
+def test_mixed_files_sharded_checkpoints():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+        # Scenario 5: Mixed files and sharded checkpoints
+        mixed_dir = temp_path / "mixed"
+        mixed_dir.mkdir()
+
+        # Create a single file
+        (mixed_dir / "single_model.safetensors").touch()
+
+        # Create sharded checkpoint
+        base_name = "diffusion_pytorch_model"
+        total_parts = 2
+        for i in range(1, total_parts + 1):
+            (mixed_dir / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors").touch()
+
+        paths = get_checkpoint_paths(mixed_dir)
+        assert len(paths) == total_parts + 1
+
+        # Verify correct handling of Path and str inputs
+        path_input = mixed_dir
+        str_input = str(mixed_dir)
+
+        path_paths = get_checkpoint_paths(path_input)
+        str_paths = get_checkpoint_paths(str_input)
+
+        assert set(path_paths) == set(str_paths)