mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge branch 'sd3' into fix-dataset-with-metadata
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ["3.10"] # Python versions to test
|
||||
pytorch-version: ["2.4.0"] # PyTorch versions to test
|
||||
pytorch-version: ["2.4.0", "2.6.0"] # PyTorch versions to test
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
17
README.md
17
README.md
@@ -4,18 +4,29 @@ This repository contains training, generation and utility scripts for Stable Dif
|
||||
|
||||
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
|
||||
|
||||
__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__
|
||||
__Please update PyTorch to 2.6.0 or later. We have tested with `torch==2.6.0` and `torchvision==0.21.0` with CUDA 12.4. `requirements.txt` is also updated, so please update the requirements.__
|
||||
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
`pip3 install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/9 should be used. `requirements.txt` will work with this version.
|
||||
|
||||
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed` (appropriate version is not confirmed yet).
|
||||
|
||||
- [FLUX.1 training](#flux1-training)
|
||||
- [SD3 training](#sd3-training)
|
||||
|
||||
### Recent Updates
|
||||
|
||||
Aug 28, 2025:
|
||||
- In order to support the latest GPUs and features, we have updated the **PyTorch and library versions**. PR [#2178](https://github.com/kohya-ss/sd-scripts/pull/2178) There are many changes, so please let us know if you encounter any issues.
|
||||
- The PyTorch version used for testing has been updated to 2.6.0. We have confirmed that it works with PyTorch 2.6.0 and later.
|
||||
- The `requirements.txt` has been updated, so please update your dependencies.
|
||||
- You can update the dependencies with `pip install -r requirements.txt`.
|
||||
- The version specification for `bitsandbytes` has been removed. If you encounter errors on RTX 50 series GPUs, please update it with `pip install -U bitsandbytes`.
|
||||
- We have modified each script to minimize warnings as much as possible.
|
||||
- The modified scripts will work in the old environment (library versions), but please update them when convenient.
|
||||
|
||||
Jul 30, 2025:
|
||||
- **Breaking Change**: For FLUX.1 and Chroma training, the CFG (Classifier-Free Guidance, using negative prompts) scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details.
|
||||
|
||||
|
||||
@@ -220,8 +220,12 @@ class DummyTextModel(torch.nn.Module):
|
||||
class DummyCLIPL(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
|
||||
self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter
|
||||
self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
|
||||
|
||||
# dtype and device from these parameters. train_network.py accesses them
|
||||
self.dummy_param = torch.nn.Parameter(torch.zeros(1))
|
||||
self.dummy_param_2 = torch.nn.Parameter(torch.zeros(1))
|
||||
self.dummy_param_3 = torch.nn.Parameter(torch.zeros(1))
|
||||
self.text_model = DummyTextModel()
|
||||
|
||||
@property
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
import diffusers
|
||||
@@ -14,8 +15,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # ,
|
||||
from safetensors.torch import load_file, save_file
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||
@@ -974,7 +977,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
||||
checkpoint = None
|
||||
state_dict = load_file(ckpt_path) # , device) # may causes error
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
|
||||
if "state_dict" in checkpoint:
|
||||
state_dict = checkpoint["state_dict"]
|
||||
else:
|
||||
|
||||
@@ -114,8 +114,10 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
|
||||
@@ -530,7 +532,9 @@ class DownBlock2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -626,15 +630,9 @@ class CrossAttention(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
attention_mask,
|
||||
) = translate_attention_names_from_diffusers(
|
||||
hidden_states=hidden_states, context=context, mask=mask, **kwargs
|
||||
)
|
||||
) = translate_attention_names_from_diffusers(hidden_states=hidden_states, context=context, mask=mask, **kwargs)
|
||||
return self.processor(
|
||||
attn=self,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=context,
|
||||
attention_mask=mask,
|
||||
**kwargs
|
||||
attn=self, hidden_states=hidden_states, encoder_hidden_states=context, attention_mask=mask, **kwargs
|
||||
)
|
||||
if self.use_memory_efficient_attention_xformers:
|
||||
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
||||
@@ -748,13 +746,14 @@ class CrossAttention(nn.Module):
|
||||
out = self.to_out[0](out)
|
||||
return out
|
||||
|
||||
|
||||
def translate_attention_names_from_diffusers(
|
||||
hidden_states: torch.FloatTensor,
|
||||
context: Optional[torch.FloatTensor] = None,
|
||||
mask: Optional[torch.FloatTensor] = None,
|
||||
# HF naming
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
# translate from hugging face diffusers
|
||||
context = context if context is not None else encoder_hidden_states
|
||||
@@ -764,6 +763,7 @@ def translate_attention_names_from_diffusers(
|
||||
|
||||
return hidden_states, context, mask
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
r"""
|
||||
@@ -1015,9 +1015,11 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -1098,10 +1100,12 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
|
||||
if attn is not None:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
|
||||
)[0]
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
||||
@@ -1201,7 +1205,9 @@ class UpBlock2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -1296,9 +1302,11 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -6008,7 +6008,6 @@ def get_noise_noisy_latents_and_timesteps(
|
||||
b_size = latents.shape[0]
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
|
||||
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
@@ -6281,7 +6280,7 @@ def sample_images_common(
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
unet_wrapped,
|
||||
prompt_replacement=None,
|
||||
controlnet=None,
|
||||
):
|
||||
@@ -6316,7 +6315,7 @@ def sample_images_common(
|
||||
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = accelerator.unwrap_model(unet_wrapped)
|
||||
if isinstance(text_encoder, (list, tuple)):
|
||||
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
|
||||
else:
|
||||
@@ -6462,7 +6461,7 @@ def sample_image_inference(
|
||||
logger.info(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
logger.info(f"seed: {seed}")
|
||||
with accelerator.autocast():
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
latents = pipeline(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
|
||||
0
pytorch_lightning/__init__.py
Normal file
0
pytorch_lightning/__init__.py
Normal file
0
pytorch_lightning/callbacks/__init__.py
Normal file
0
pytorch_lightning/callbacks/__init__.py
Normal file
4
pytorch_lightning/callbacks/model_checkpoint.py
Normal file
4
pytorch_lightning/callbacks/model_checkpoint.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# dummy module for pytorch_lightning
|
||||
|
||||
class ModelCheckpoint:
|
||||
pass
|
||||
@@ -1,28 +1,29 @@
|
||||
accelerate==0.33.0
|
||||
transformers==4.44.0
|
||||
diffusers[torch]==0.25.0
|
||||
ftfy==6.1.1
|
||||
accelerate==1.6.0
|
||||
transformers==4.54.1
|
||||
diffusers[torch]==0.32.1
|
||||
ftfy==6.3.1
|
||||
# albumentations==1.3.0
|
||||
opencv-python==4.8.1.78
|
||||
opencv-python==4.10.0.84
|
||||
einops==0.7.0
|
||||
pytorch-lightning==1.9.0
|
||||
bitsandbytes==0.44.0
|
||||
lion-pytorch==0.0.6
|
||||
# pytorch-lightning==1.9.0
|
||||
bitsandbytes
|
||||
lion-pytorch==0.2.3
|
||||
schedulefree==1.4
|
||||
pytorch-optimizer==3.7.0
|
||||
prodigy-plus-schedule-free==1.9.0
|
||||
prodigy-plus-schedule-free==1.9.2
|
||||
prodigyopt==1.1.2
|
||||
tensorboard
|
||||
safetensors==0.4.4
|
||||
safetensors==0.4.5
|
||||
# gradio==3.16.2
|
||||
altair==4.2.2
|
||||
easygui==0.98.3
|
||||
# altair==4.2.2
|
||||
# easygui==0.98.3
|
||||
toml==0.10.2
|
||||
voluptuous==0.13.1
|
||||
huggingface-hub==0.24.5
|
||||
voluptuous==0.15.2
|
||||
huggingface-hub==0.34.3
|
||||
# for Image utils
|
||||
imagesize==1.4.1
|
||||
numpy<=2.0
|
||||
numpy
|
||||
# <=2.0
|
||||
# for BLIP captioning
|
||||
# requests==2.28.2
|
||||
# timm==0.6.12
|
||||
@@ -41,8 +42,8 @@ numpy<=2.0
|
||||
# open clip for SDXL
|
||||
# open-clip-torch==2.20.0
|
||||
# For logging
|
||||
rich==13.7.0
|
||||
rich==14.1.0
|
||||
# for T5XXL tokenizer (SD3/FLUX)
|
||||
sentencepiece==0.2.0
|
||||
sentencepiece==0.2.1
|
||||
# for kohya_ss library
|
||||
-e .
|
||||
|
||||
@@ -23,7 +23,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
def assert_extra_args(
|
||||
self,
|
||||
args,
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
|
||||
@@ -20,7 +20,6 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
@@ -12,7 +12,7 @@ import toml
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library import deepspeed_utils
|
||||
from library import deepspeed_utils, strategy_base, strategy_sd
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -73,7 +73,14 @@ def train(args):
|
||||
args.seed = random.randint(0, 2**32)
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||
tokenizer = tokenize_strategy.tokenizer
|
||||
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||
True, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||
)
|
||||
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
@@ -100,7 +107,7 @@ def train(args):
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
@@ -243,12 +250,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(
|
||||
vae,
|
||||
args.vae_batch_size,
|
||||
args.cache_latents_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -267,6 +269,7 @@ def train(args):
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
train_dataset_group.set_current_strategies()
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
@@ -451,7 +454,7 @@ def train(args):
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
|
||||
@@ -414,13 +414,12 @@ class NetworkTrainer:
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
@@ -1340,7 +1339,7 @@ class NetworkTrainer:
|
||||
)
|
||||
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
|
||||
validation_total_steps = validation_steps * len(validation_timesteps)
|
||||
original_args_min_timestep = args.min_timestep
|
||||
|
||||
Reference in New Issue
Block a user