Merge branch 'sd3' into fix-dataset-with-metadata

This commit is contained in:
Kohya S
2025-08-30 10:32:44 +09:00
14 changed files with 96 additions and 60 deletions

View File

@@ -22,7 +22,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version: ["3.10"] # Python versions to test
pytorch-version: ["2.4.0"] # PyTorch versions to test
pytorch-version: ["2.4.0", "2.6.0"] # PyTorch versions to test
steps:
- uses: actions/checkout@v4

View File

@@ -4,18 +4,29 @@ This repository contains training, generation and utility scripts for Stable Dif
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__
__Please update PyTorch to 2.6.0 or later. We have tested with `torch==2.6.0` and `torchvision==0.21.0` with CUDA 12.4. `requirements.txt` is also updated, so please update the requirements.__
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
`pip3 install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124`
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/9 should be used. `requirements.txt` will work with this version.
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed` (appropriate version is not confirmed yet).
- [FLUX.1 training](#flux1-training)
- [SD3 training](#sd3-training)
### Recent Updates
Aug 28, 2025:
- In order to support the latest GPUs and features, we have updated the **PyTorch and library versions**. PR [#2178](https://github.com/kohya-ss/sd-scripts/pull/2178) There are many changes, so please let us know if you encounter any issues.
- The PyTorch version used for testing has been updated to 2.6.0. We have confirmed that it works with PyTorch 2.6.0 and later.
- The `requirements.txt` has been updated, so please update your dependencies.
- You can update the dependencies with `pip install -r requirements.txt`.
- The version specification for `bitsandbytes` has been removed. If you encounter errors on RTX 50 series GPUs, please update it with `pip install -U bitsandbytes`.
- We have modified each script to minimize warnings as much as possible.
- The modified scripts will work in the old environment (library versions), but please update them when convenient.
Jul 30, 2025:
- **Breaking Change**: For FLUX.1 and Chroma training, the CFG (Classifier-Free Guidance, using negative prompts) scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details.

View File

@@ -220,8 +220,12 @@ class DummyTextModel(torch.nn.Module):
class DummyCLIPL(torch.nn.Module):
def __init__(self):
super().__init__()
self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter
self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
# dtype and device from these parameters. train_network.py accesses them
self.dummy_param = torch.nn.Parameter(torch.zeros(1))
self.dummy_param_2 = torch.nn.Parameter(torch.zeros(1))
self.dummy_param_3 = torch.nn.Parameter(torch.zeros(1))
self.text_model = DummyTextModel()
@property

View File

@@ -6,6 +6,7 @@ import os
import torch
from library.device_utils import init_ipex
init_ipex()
import diffusers
@@ -14,8 +15,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # ,
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# DiffUsers版StableDiffusionのモデルパラメータ
@@ -974,7 +977,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
checkpoint = None
state_dict = load_file(ckpt_path) # , device) # may causes error
else:
checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:

View File

@@ -114,8 +114,10 @@ from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
@@ -530,7 +532,9 @@ class DownBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = resnet(hidden_states, temb)
@@ -626,15 +630,9 @@ class CrossAttention(nn.Module):
hidden_states,
encoder_hidden_states,
attention_mask,
) = translate_attention_names_from_diffusers(
hidden_states=hidden_states, context=context, mask=mask, **kwargs
)
) = translate_attention_names_from_diffusers(hidden_states=hidden_states, context=context, mask=mask, **kwargs)
return self.processor(
attn=self,
hidden_states=hidden_states,
encoder_hidden_states=context,
attention_mask=mask,
**kwargs
attn=self, hidden_states=hidden_states, encoder_hidden_states=context, attention_mask=mask, **kwargs
)
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
@@ -748,13 +746,14 @@ class CrossAttention(nn.Module):
out = self.to_out[0](out)
return out
def translate_attention_names_from_diffusers(
hidden_states: torch.FloatTensor,
context: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
# HF naming
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None
attention_mask: Optional[torch.FloatTensor] = None,
):
# translate from hugging face diffusers
context = context if context is not None else encoder_hidden_states
@@ -764,6 +763,7 @@ def translate_attention_names_from_diffusers(
return hidden_states, context, mask
# feedforward
class GEGLU(nn.Module):
r"""
@@ -1015,9 +1015,11 @@ class CrossAttnDownBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
)[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -1098,10 +1100,12 @@ class UNetMidBlock2DCrossAttn(nn.Module):
if attn is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
if attn is not None:
hidden_states = attn(hidden_states, encoder_hidden_states).sample
@@ -1201,7 +1205,9 @@ class UpBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = resnet(hidden_states, temb)
@@ -1296,9 +1302,11 @@ class CrossAttnUpBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
)[0]
else:
hidden_states = resnet(hidden_states, temb)

View File

@@ -6008,7 +6008,6 @@ def get_noise_noisy_latents_and_timesteps(
b_size = latents.shape[0]
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
@@ -6281,7 +6280,7 @@ def sample_images_common(
vae,
tokenizer,
text_encoder,
unet,
unet_wrapped,
prompt_replacement=None,
controlnet=None,
):
@@ -6316,7 +6315,7 @@ def sample_images_common(
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
unet = accelerator.unwrap_model(unet_wrapped)
if isinstance(text_encoder, (list, tuple)):
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
else:
@@ -6462,7 +6461,7 @@ def sample_image_inference(
logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
with accelerator.autocast():
with accelerator.autocast(), torch.no_grad():
latents = pipeline(
prompt=prompt,
height=height,

View File

View File

View File

@@ -0,0 +1,4 @@
# dummy module for pytorch_lightning
class ModelCheckpoint:
pass

View File

@@ -1,28 +1,29 @@
accelerate==0.33.0
transformers==4.44.0
diffusers[torch]==0.25.0
ftfy==6.1.1
accelerate==1.6.0
transformers==4.54.1
diffusers[torch]==0.32.1
ftfy==6.3.1
# albumentations==1.3.0
opencv-python==4.8.1.78
opencv-python==4.10.0.84
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.44.0
lion-pytorch==0.0.6
# pytorch-lightning==1.9.0
bitsandbytes
lion-pytorch==0.2.3
schedulefree==1.4
pytorch-optimizer==3.7.0
prodigy-plus-schedule-free==1.9.0
prodigy-plus-schedule-free==1.9.2
prodigyopt==1.1.2
tensorboard
safetensors==0.4.4
safetensors==0.4.5
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
# altair==4.2.2
# easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.24.5
voluptuous==0.15.2
huggingface-hub==0.34.3
# for Image utils
imagesize==1.4.1
numpy<=2.0
numpy
# <=2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
@@ -41,8 +42,8 @@ numpy<=2.0
# open clip for SDXL
# open-clip-torch==2.20.0
# For logging
rich==13.7.0
rich==14.1.0
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece==0.2.0
sentencepiece==0.2.1
# for kohya_ss library
-e .

View File

@@ -23,7 +23,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:

View File

@@ -20,7 +20,6 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
train_dataset_group.verify_bucket_reso_steps(32)

View File

@@ -12,7 +12,7 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library import deepspeed_utils, strategy_base, strategy_sd
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -73,7 +73,14 @@ def train(args):
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizer = tokenize_strategy.tokenizer
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, False
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
@@ -100,7 +107,7 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
@@ -243,12 +250,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -267,6 +269,7 @@ def train(args):
# dataloaderを準備する
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
train_dataset_group.set_current_strategies()
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -451,7 +454,7 @@ def train(args):
latents = latents * 0.18215
b_size = latents.shape[0]
input_ids = batch["input_ids"].to(accelerator.device)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents

View File

@@ -414,13 +414,12 @@ class NetworkTrainer:
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
@@ -1340,7 +1339,7 @@ class NetworkTrainer:
)
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
validation_total_steps = validation_steps * len(validation_timesteps)
original_args_min_timestep = args.min_timestep