This commit is contained in:
Dave Lage
2025-09-30 09:53:44 +05:30
committed by GitHub
2 changed files with 153 additions and 6 deletions

View File

@@ -1,5 +1,7 @@
import json
import os
from pathlib import Path
import re
from dataclasses import replace
from typing import List, Optional, Tuple, Union
@@ -26,6 +28,63 @@ MODEL_NAME_SCHNELL = "schnell"
# String tag naming the "chroma" model version.
# NOTE(review): presumably compared against a detected/requested model version elsewhere in this module — confirm at the call sites.
MODEL_VERSION_CHROMA = "chroma"
def get_checkpoint_paths(ckpt_path: str | Path):
"""
Get checkpoint paths for flux models
- huggingface directory structure
- huggingface sharded safetensors files
- in transformer directory
- plain directory
- single safetensor files
"""
if not isinstance(ckpt_path, Path):
# Convert to Path object
ckpt_path = Path(ckpt_path)
# If ckpt_path is a directory
if ckpt_path.is_dir():
# List to store potential checkpoint paths
potential_paths = []
# Check for files directly in the directory
potential_paths.extend(ckpt_path.glob('*.safetensors'))
# Check for files in the transformer subdirectory
transformer_path = ckpt_path / 'transformer'
if transformer_path.is_dir():
potential_paths.extend(transformer_path.glob('*.safetensors'))
# Filter and expand multi-part checkpoint paths
checkpoint_paths = []
for path in potential_paths:
# If it's a multi-part checkpoint
if '-of-' in path.name:
# Use regex to extract parts
match = re.search(r'(.+?)-(\d+)-of-(\d+)', path.name)
if match:
base_name, current_part, total_parts = match.groups()
# Generate all part paths
part_paths = [
path.with_name(f'{base_name}-{i:05d}-of-{int(total_parts):05d}.safetensors')
for i in range(1, int(total_parts) + 1)
]
checkpoint_paths.extend(part_paths)
else:
# Single file checkpoint
checkpoint_paths.append(path)
# Remove duplicates while preserving order
checkpoint_paths = list(dict.fromkeys(checkpoint_paths))
else:
# If ckpt_path is a single file
checkpoint_paths = [ckpt_path]
return checkpoint_paths
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
"""
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
@@ -43,12 +102,7 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
if "00001-of-00003" in ckpt_path:
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
else:
ckpt_paths = [ckpt_path]
ckpt_paths = get_checkpoint_paths(ckpt_path)
keys = []
for ckpt_path in ckpt_paths: