mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
4.4 KiB
This file provides an overview and guidance for developers working with the codebase, including setup instructions, architecture details, and common commands.
Project Architecture
Core Training Framework
The codebase is built around a strategy pattern architecture that supports multiple diffusion model families:
- library/strategy_base.py: Base classes for tokenization, text encoding, latent caching, and training strategies
- library/strategy_*.py: Model-specific implementations for SD, SDXL, SD3, FLUX, etc.
- library/train_util.py: Core training utilities shared across all model types
- library/config_util.py: Configuration management with TOML support
Model Support Structure
Each supported model family has a consistent structure:
- Training script: {model}_train.py (full fine-tuning), {model}_train_network.py (LoRA/network training)
- Model utilities: library/{model}_models.py, library/{model}_train_utils.py, library/{model}_utils.py
- Networks: networks/lora_{model}.py, networks/oft_{model}.py for adapter training
Supported Models
- Stable Diffusion 1.x: train*.py, library/train_util.py, train_db.py (for DreamBooth)
- SDXL: sdxl_train*.py, library/sdxl_*
- SD3: sd3_train*.py, library/sd3_*
- FLUX.1: flux_train*.py, library/flux_*
- Lumina Image 2.0: lumina_train*.py, library/lumina_*
- HunyuanImage-2.1: hunyuan_image_train*.py, library/hunyuan_image_*
- Anima-Preview: anima_train*.py, library/anima_*
Key Components
Memory Management
- Block swapping: CPU-GPU memory optimization via the --blocks_to_swap parameter; works with custom offloading. Only available for models with transformer architectures such as SD3 and FLUX.1.
- Custom offloading: library/custom_offloading_utils.py for advanced memory management
- Gradient checkpointing: Memory reduction during training
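To make the block-swapping idea concrete, here is a minimal conceptual sketch. The names (Block, swap_blocks) are illustrative only and are not the real custom_offloading_utils API; the real implementation moves actual transformer modules between devices.

```python
# Conceptual sketch of block swapping: the first N transformer blocks
# are parked on the CPU and only moved to the GPU when needed,
# mirroring the effect of the --blocks_to_swap parameter.
# (Block and swap_blocks are hypothetical names for illustration.)
from dataclasses import dataclass


@dataclass
class Block:
    index: int
    device: str = "cuda"


def swap_blocks(blocks, blocks_to_swap):
    """Move the first `blocks_to_swap` blocks to the CPU."""
    for block in blocks[:blocks_to_swap]:
        block.device = "cpu"
    return blocks


blocks = swap_blocks([Block(i) for i in range(28)], blocks_to_swap=10)
print(sum(1 for b in blocks if b.device == "cpu"))  # 10
```

Increasing blocks_to_swap lowers peak VRAM at the cost of more CPU-GPU transfers per step.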
Training Features
- LoRA training: Low-rank adaptation networks in networks/lora*.py
- ControlNet training: Conditional generation control
- Textual Inversion: Custom embedding training
- Multi-resolution training: Bucket-based aspect ratio handling
- Validation loss: Real-time training monitoring, only for LoRA training
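To illustrate the bucket-based aspect ratio handling mentioned above, here is a simplified sketch. It is not the exact train_util implementation; bucket step sizes and limits are illustrative assumptions.

```python
# Illustrative sketch of aspect-ratio bucketing: enumerate bucket
# resolutions whose pixel area stays at or below a target area, then
# assign each image to the bucket with the closest aspect ratio.
def make_buckets(max_area=1024 * 1024, step=64, min_size=512, max_size=2048):
    buckets = set()
    width = min_size
    while width <= max_size:
        # Largest step-aligned height that keeps area <= max_area
        height = (max_area // width) // step * step
        if min_size <= height <= max_size:
            buckets.add((width, height))
            buckets.add((height, width))
        width += step
    return sorted(buckets)


def assign_bucket(image_w, image_h, buckets):
    """Pick the bucket whose aspect ratio is closest to the image's."""
    ratio = image_w / image_h
    return min(buckets, key=lambda b: abs(b[0] / b[1] - ratio))


buckets = make_buckets()
print(assign_bucket(1920, 1080, buckets))  # a landscape bucket near 16:9
```

Images in the same bucket are batched together, so mixed-aspect-ratio datasets can be trained without cropping everything to a single resolution.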
Configuration System
Dataset configuration uses TOML files with structured validation:
[datasets.sample_dataset]
resolution = 1024
batch_size = 2
[[datasets.sample_dataset.subsets]]
image_dir = "path/to/images"
caption_extension = ".txt"
Common Development Commands
Training Commands Pattern
All training scripts follow this general pattern:
accelerate launch --mixed_precision bf16 {script_name}.py \
--pretrained_model_name_or_path model.safetensors \
--dataset_config config.toml \
--output_dir output \
--output_name model_name \
[model-specific options]
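When driving training from a launcher script, the same pattern can be assembled programmatically. This is a hedged sketch with placeholder paths; build_training_command is a hypothetical helper, not part of the repository.

```python
# Sketch: assemble the general training command as an argument list
# suitable for subprocess.run. All paths/names below are placeholders.
def build_training_command(script, model_path, dataset_config,
                           output_dir, output_name, extra_args=()):
    return [
        "accelerate", "launch", "--mixed_precision", "bf16", script,
        "--pretrained_model_name_or_path", model_path,
        "--dataset_config", dataset_config,
        "--output_dir", output_dir,
        "--output_name", output_name,
        *extra_args,  # model-specific options go here
    ]


cmd = build_training_command(
    "sdxl_train_network.py", "model.safetensors", "config.toml",
    "output", "model_name", ["--network_module", "networks.lora"],
)
# Run with subprocess.run(cmd) once the model and config paths exist.
```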
Memory Optimization
For low VRAM environments, use block swapping:
# Add to any training command for memory reduction
--blocks_to_swap 10 # Swap 10 blocks to CPU (adjust number as needed)
Utility Scripts
Located in tools/ directory:
- tools/merge_lora.py: Merge LoRA weights into base models
- tools/cache_latents.py: Pre-cache VAE latents for faster training
- tools/cache_text_encoder_outputs.py: Pre-cache text encoder outputs
Development Notes
Strategy Pattern Implementation
When adding support for new models, implement the four core strategies:
- TokenizeStrategy: Text tokenization handling
- TextEncodingStrategy: Text encoder forward pass
- LatentsCachingStrategy: VAE encoding/caching
- TextEncoderOutputsCachingStrategy: Text encoder output caching
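The four strategies above can be sketched as abstract base classes. The real base classes live in library/strategy_base.py and have richer signatures; the method names and the placeholder implementation here are illustrative assumptions.

```python
# Minimal sketch of the four-strategy pattern: each new model family
# subclasses all four. Signatures are simplified for illustration.
from abc import ABC, abstractmethod


class TokenizeStrategy(ABC):
    @abstractmethod
    def tokenize(self, text: str):
        ...


class TextEncodingStrategy(ABC):
    @abstractmethod
    def encode_tokens(self, tokenize_strategy, models, tokens):
        ...


class LatentsCachingStrategy(ABC):
    @abstractmethod
    def cache_batch_latents(self, vae, batch):
        ...


class TextEncoderOutputsCachingStrategy(ABC):
    @abstractmethod
    def cache_batch_outputs(self, tokenize_strategy, models, batch):
        ...


# A new model family implements each strategy, e.g.:
class MyModelTokenizeStrategy(TokenizeStrategy):
    def tokenize(self, text: str):
        return text.split()  # placeholder tokenizer for illustration
```

The training loop then depends only on the abstract interfaces, so adding a model family never requires touching the shared training code.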
Testing Approach
- Unit tests focus on utility functions and model loading
- Integration tests validate training script syntax and basic execution
- Most tests use mocks to avoid requiring actual model files
- Add tests for new model support in tests/test_{model}_*.py
Configuration System
- Use config_util.py dataclasses for type-safe configuration
- Support both command-line arguments and TOML file configuration
- Validate configuration early in training scripts to prevent runtime errors
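The points above can be sketched as follows. The dataclass names mirror the TOML example earlier in this document but are illustrative, not the actual config_util.py classes.

```python
# Sketch of early, type-safe config validation with dataclasses:
# fail fast before any model weights are loaded.
from dataclasses import dataclass, field


@dataclass
class SubsetConfig:
    image_dir: str
    caption_extension: str = ".txt"


@dataclass
class DatasetConfig:
    resolution: int
    batch_size: int
    subsets: list = field(default_factory=list)

    def validate(self):
        if self.resolution <= 0:
            raise ValueError("resolution must be positive")
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not self.subsets:
            raise ValueError("at least one subset is required")


config = DatasetConfig(
    resolution=1024, batch_size=2,
    subsets=[SubsetConfig(image_dir="path/to/images")],
)
config.validate()  # raises immediately on a bad config
```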
Memory Management
- Always consider VRAM limitations when implementing features
- Use gradient checkpointing for large models
- Implement block swapping for models with transformer architectures
- Cache intermediate results (latents, text embeddings) when possible
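The caching point can be sketched with a simple content-keyed cache. This is an illustrative in-memory version, not the project's on-disk cache format for latents and text-encoder outputs.

```python
# Sketch: cache expensive intermediate results (e.g. text embeddings)
# keyed by a hash of the input, so repeated epochs skip recomputation.
import hashlib


class EmbeddingCache:
    def __init__(self):
        self._store = {}

    def _key(self, text: str) -> str:
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def get_or_compute(self, text, compute):
        key = self._key(text)
        if key not in self._store:
            self._store[key] = compute(text)  # the expensive encoder call
        return self._store[key]


cache = EmbeddingCache()
calls = []


def fake_encoder(text):  # stand-in for a real text encoder forward pass
    calls.append(text)
    return [len(text)]


cache.get_or_compute("a photo of a cat", fake_encoder)
cache.get_or_compute("a photo of a cat", fake_encoder)
print(len(calls))  # 1 -- the encoder ran only once
```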