Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-18 09:30:28 +00:00)

Compare commits: 11 commits, 511148e2c3 ... 74e272ebd2
Commits in this comparison:

- 74e272ebd2
- 5462a6bb24
- 63711390a0
- 206adb6438
- 60bfa97b19
- f0c767e0f2
- a0c26a0efa
- 67d0621313
- 6a826d21b1
- 24ab4c0c4a
- c0f2808763
@@ -190,7 +190,7 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like
 * `--fp8_vl`
   - Use FP8 for the VLM (Qwen2.5-VL) text encoder.
 * `--text_encoder_cpu`
-  - Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts.
+  - Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts. **In addition, increasing `--num_cpu_threads_per_process` in the `accelerate launch` command, like `--num_cpu_threads_per_process=8` or `16`, can speed up encoding in some environments.**
 * `--blocks_to_swap=<integer>` **[Experimental Feature]**
   - Setting to reduce VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used with `gradient_checkpointing`.
 * `--cache_text_encoder_outputs`
@@ -249,7 +249,15 @@ def sample_image_inference(
     arg_c_null = None

     gen_args = SimpleNamespace(
-        image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale, fp8=args.fp8_scaled
+        image_size=(height, width),
+        infer_steps=sample_steps,
+        flow_shift=flow_shift,
+        guidance_scale=cfg_scale,
+        fp8=args.fp8_scaled,
+        apg_start_step_ocr=38,
+        apg_start_step_general=5,
+        guidance_rescale=0.0,
+        guidance_rescale_apg=0.0,
     )

     from hunyuan_image_minimal_inference import generate_body  # import here to avoid circular import
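The new APG and guidance-rescale fields are simply extra attributes on the `SimpleNamespace` handed to the inference code. As background for readers unfamiliar with the pattern, here is a minimal, self-contained sketch (not the repository's code; `generate` and its fields are illustrative) of why a `SimpleNamespace` works wherever an `argparse.Namespace`-like object is expected:

```python
from types import SimpleNamespace


def generate(gen_args):
    # The callee only needs attribute access, so argparse.Namespace and
    # types.SimpleNamespace are interchangeable for this purpose.
    w, h = gen_args.image_size
    return f"{w}x{h}, steps={gen_args.infer_steps}, cfg={gen_args.guidance_scale}"


# One keyword per line, matching the diff's style, so future additions
# (such as apg_start_step_ocr above) show up as single-line changes.
gen_args = SimpleNamespace(
    image_size=(1024, 1024),
    infer_steps=20,
    guidance_scale=3.5,
)
print(generate(gen_args))  # 1024x1024, steps=20, cfg=3.5
```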
@@ -1,5 +1,7 @@
 import json
 import os
+from pathlib import Path
+import re
 from dataclasses import replace
 from typing import List, Optional, Tuple, Union

@@ -26,6 +28,63 @@ MODEL_NAME_SCHNELL = "schnell"
 MODEL_VERSION_CHROMA = "chroma"


+def get_checkpoint_paths(ckpt_path: str | Path):
+    """
+    Get checkpoint paths for flux models
+
+    - huggingface directory structure
+    - huggingface sharded safetensors files
+    - in transformer directory
+    - plain directory
+    - single safetensor files
+    """
+    if not isinstance(ckpt_path, Path):
+        # Convert to Path object
+        ckpt_path = Path(ckpt_path)
+
+    # If ckpt_path is a directory
+    if ckpt_path.is_dir():
+        # List to store potential checkpoint paths
+        potential_paths = []
+
+        # Check for files directly in the directory
+        potential_paths.extend(ckpt_path.glob('*.safetensors'))
+
+        # Check for files in the transformer subdirectory
+        transformer_path = ckpt_path / 'transformer'
+        if transformer_path.is_dir():
+            potential_paths.extend(transformer_path.glob('*.safetensors'))
+
+        # Filter and expand multi-part checkpoint paths
+        checkpoint_paths = []
+        for path in potential_paths:
+            # If it's a multi-part checkpoint
+            if '-of-' in path.name:
+                # Use regex to extract parts
+                match = re.search(r'(.+?)-(\d+)-of-(\d+)', path.name)
+                if match:
+                    base_name, current_part, total_parts = match.groups()
+
+                    # Generate all part paths
+                    part_paths = [
+                        path.with_name(f'{base_name}-{i:05d}-of-{int(total_parts):05d}.safetensors')
+                        for i in range(1, int(total_parts) + 1)
+                    ]
+
+                    checkpoint_paths.extend(part_paths)
+            else:
+                # Single file checkpoint
+                checkpoint_paths.append(path)
+
+        # Remove duplicates while preserving order
+        checkpoint_paths = list(dict.fromkeys(checkpoint_paths))
+
+    else:
+        # If ckpt_path is a single file
+        checkpoint_paths = [ckpt_path]
+
+    return checkpoint_paths
+
 def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
     """
     チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
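As a quick illustration of the new helper (a sketch; the model paths below are hypothetical and not part of the repository), both a single `.safetensors` file and a Diffusers-style directory with a `transformer/` subfolder of sharded files resolve to a flat list of checkpoint paths:

```python
from pathlib import Path

from library.flux_utils import get_checkpoint_paths

# Hypothetical locations; substitute wherever the FLUX checkpoints actually live.
sharded_dir = Path("models/flux1-dev-diffusers")    # contains transformer/*-0000X-of-00003.safetensors
single_file = Path("models/flux1-dev.safetensors")  # one plain checkpoint file

for ckpt in (sharded_dir, single_file):
    paths = get_checkpoint_paths(ckpt)  # accepts either str or Path
    print(ckpt, "->", [p.name for p in paths])
```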
@@ -43,12 +102,7 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
     # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
     logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")

-    if os.path.isdir(ckpt_path):  # if ckpt_path is a directory, it is Diffusers
-        ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
-    if "00001-of-00003" in ckpt_path:
-        ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
-    else:
-        ckpt_paths = [ckpt_path]
+    ckpt_paths = get_checkpoint_paths(ckpt_path)

     keys = []
     for ckpt_path in ckpt_paths:
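The removed code only expanded the hard-coded `00001-of-00003` Diffusers naming; the helper's regex generalizes this to any base name and shard count. A standalone sketch of that expansion step (the file name below is made up):

```python
import re

# Same pattern the helper uses to expand one shard name into the full set.
name = "diffusion_pytorch_model-00001-of-00005.safetensors"
match = re.search(r"(.+?)-(\d+)-of-(\d+)", name)
base_name, _, total_parts = match.groups()
all_parts = [
    f"{base_name}-{i:05d}-of-{int(total_parts):05d}.safetensors"
    for i in range(1, int(total_parts) + 1)
]
print(all_parts)  # shards 00001 through 00005
```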
@@ -327,14 +327,17 @@ def save_sd_model_on_epoch_end_or_stepwise(


 def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
-    parser.add_argument(
-        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
-    )
-    parser.add_argument(
-        "--cache_text_encoder_outputs_to_disk",
-        action="store_true",
-        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
-    )
+    if support_text_encoder_caching:
+        parser.add_argument(
+            "--cache_text_encoder_outputs",
+            action="store_true",
+            help="cache text encoder outputs / text encoderの出力をキャッシュする",
+        )
+        parser.add_argument(
+            "--cache_text_encoder_outputs_to_disk",
+            action="store_true",
+            help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
+        )
     parser.add_argument(
         "--disable_mmap_load_safetensors",
         action="store_true",
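The point of the new `support_text_encoder_caching` flag is that the two caching options are only registered when the calling trainer actually supports them. A minimal sketch of the pattern (the function name here is illustrative, not the repository's):

```python
import argparse


def add_example_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
    # Register the caching flags only for trainers that can honor them,
    # so the others do not expose dead command-line options.
    if support_text_encoder_caching:
        parser.add_argument("--cache_text_encoder_outputs", action="store_true")
        parser.add_argument("--cache_text_encoder_outputs_to_disk", action="store_true")


parser = argparse.ArgumentParser()
add_example_arguments(parser, support_text_encoder_caching=False)
args = parser.parse_args([])
print(hasattr(args, "cache_text_encoder_outputs"))  # False: flag was never registered
```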
@@ -342,7 +345,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en
     )


-def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
+def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True):
     assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"

     if args.clip_skip is not None:
@@ -365,7 +368,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
     #     not hasattr(args, "weighted_captions") or not args.weighted_captions
     # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

-    if supportTextEncoderCaching:
+    if support_text_encoder_caching:
         if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
             args.cache_text_encoder_outputs = True
             logger.warning(
@@ -20,7 +20,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
         self.is_sdxl = True

     def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
-        sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
+        # super().assert_extra_args(args, train_dataset_group)  # do not call parent because it checks reso steps with 64
+        sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)

         train_dataset_group.verify_bucket_reso_steps(32)
         if val_dataset_group is not None:
tests/library/test_flux_utils.py (new file, 93 lines):

import pytest
from pathlib import Path
import tempfile

from library.flux_utils import get_checkpoint_paths


def test_get_checkpoint_paths():
    # Create a temporary directory for testing
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)

        # Scenario 1: Single safetensors file in root directory
        single_file = temp_path / "model.safetensors"
        single_file.touch()
        paths = get_checkpoint_paths(str(single_file))
        assert len(paths) == 1
        assert paths[0] == single_file


def test_multiple_root_checkpoint_paths():
    """
    Multiple single safetensors files in root directory
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        # Scenario 2:
        file1 = temp_path / "model1.safetensors"
        file2 = temp_path / "model2.safetensors"
        file1.touch()
        file2.touch()
        paths = get_checkpoint_paths(temp_path)
        assert len(paths) == 2
        assert set(paths) == {file1, file2}


def test_multipart_sharded_checkpoint():
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        # Scenario 3: Sharded multi-part checkpoint
        # Create sharded checkpoint files
        base_name = "diffusion_pytorch_model"
        total_parts = 3
        for i in range(1, total_parts + 1):
            (temp_path / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors").touch()

        paths = get_checkpoint_paths(temp_path)
        assert len(paths) == total_parts

        # Check if all expected part paths are present
        expected_paths = [temp_path / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors" for i in range(1, total_parts + 1)]
        assert set(paths) == set(expected_paths)


def test_transformer_model_dir():
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        transformer_dir = temp_path / "transformer"
        transformer_dir.mkdir()
        transformer_file = transformer_dir / "diffusion_pytorch_model.safetensors"
        transformer_file.touch()

        paths = get_checkpoint_paths(temp_path)
        assert transformer_file in paths


def test_mixed_files_sharded_checkpoints():
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        # Scenario 5: Mixed files and sharded checkpoints
        mixed_dir = temp_path / "mixed"
        mixed_dir.mkdir()

        # Create a single file
        (mixed_dir / "single_model.safetensors").touch()

        # Create sharded checkpoint
        base_name = "diffusion_pytorch_model"
        total_parts = 2
        for i in range(1, total_parts + 1):
            (mixed_dir / f"{base_name}-{i:05d}-of-{total_parts:05d}.safetensors").touch()

        paths = get_checkpoint_paths(mixed_dir)
        assert len(paths) == total_parts + 1

        # Verify correct handling of Path and str inputs
        path_input = mixed_dir
        str_input = str(mixed_dir)

        path_paths = get_checkpoint_paths(path_input)
        str_paths = get_checkpoint_paths(str_input)

        assert set(path_paths) == set(str_paths)
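For completeness, the new tests can also be invoked programmatically; a small sketch, assuming pytest is installed and this is run from the repository root:

```python
# Run only the new flux_utils tests; equivalent to calling pytest on the file.
import pytest

raise SystemExit(pytest.main(["tests/library/test_flux_utils.py", "-q"]))
```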