diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9e037e53..d35fe392 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: matrix: os: [ubuntu-latest] python-version: ["3.10"] # Python versions to test - pytorch-version: ["2.4.0"] # PyTorch versions to test + pytorch-version: ["2.4.0", "2.6.0"] # PyTorch versions to test steps: - uses: actions/checkout@v4 diff --git a/README.md b/README.md index 7b665dc2..5e569eab 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,29 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. -__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ +__Please update PyTorch to 2.6.0 or later. We have tested with `torch==2.6.0` and `torchvision==0.21.0` with CUDA 12.4. `requirements.txt` is also updated, so please update the requirements.__ The command to install PyTorch is as follows: -`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +`pip3 install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124` -If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`. +For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/9 should be used. `requirements.txt` will work with this version. + +If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed` (appropriate version is not confirmed yet). - [FLUX.1 training](#flux1-training) - [SD3 training](#sd3-training) ### Recent Updates +Aug 28, 2025: +- In order to support the latest GPUs and features, we have updated the **PyTorch and library versions**. PR [#2178](https://github.com/kohya-ss/sd-scripts/pull/2178) There are many changes, so please let us know if you encounter any issues. +- The PyTorch version used for testing has been updated to 2.6.0. We have confirmed that it works with PyTorch 2.6.0 and later. +- The `requirements.txt` has been updated, so please update your dependencies. + - You can update the dependencies with `pip install -r requirements.txt`. + - The version specification for `bitsandbytes` has been removed. If you encounter errors on RTX 50 series GPUs, please update it with `pip install -U bitsandbytes`. +- We have modified each script to minimize warnings as much as possible. + - The modified scripts will work in the old environment (library versions), but please update them when convenient. + Jul 30, 2025: - **Breaking Change**: For FLUX.1 and Chroma training, the CFG (Classifier-Free Guidance, using negative prompts) scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details. diff --git a/library/flux_utils.py b/library/flux_utils.py index 3f0a0d63..22054854 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -220,8 +220,12 @@ class DummyTextModel(torch.nn.Module): class DummyCLIPL(torch.nn.Module): def __init__(self): super().__init__() - self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output - self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter + self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output + + # dtype and device from these parameters. train_network.py accesses them + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + self.dummy_param_2 = torch.nn.Parameter(torch.zeros(1)) + self.dummy_param_3 = torch.nn.Parameter(torch.zeros(1)) self.text_model = DummyTextModel() @property diff --git a/library/model_util.py b/library/model_util.py index 9918c7b2..bcaa1145 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -6,6 +6,7 @@ import os import torch from library.device_utils import init_ipex + init_ipex() import diffusers @@ -14,8 +15,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # DiffUsers版StableDiffusionのモデルパラメータ @@ -974,7 +977,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): checkpoint = None state_dict = load_file(ckpt_path) # , device) # may causes error else: - checkpoint = torch.load(ckpt_path, map_location=device) + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False) if "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] else: diff --git a/library/original_unet.py b/library/original_unet.py index e944ff22..aa9dc233 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -114,8 +114,10 @@ from torch import nn from torch.nn import functional as F from einops import rearrange from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) @@ -530,7 +532,9 @@ class DownBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: hidden_states = resnet(hidden_states, temb) @@ -626,15 +630,9 @@ class CrossAttention(nn.Module): hidden_states, encoder_hidden_states, attention_mask, - ) = translate_attention_names_from_diffusers( - hidden_states=hidden_states, context=context, mask=mask, **kwargs - ) + ) = translate_attention_names_from_diffusers(hidden_states=hidden_states, context=context, mask=mask, **kwargs) return self.processor( - attn=self, - hidden_states=hidden_states, - encoder_hidden_states=context, - attention_mask=mask, - **kwargs + attn=self, hidden_states=hidden_states, encoder_hidden_states=context, attention_mask=mask, **kwargs ) if self.use_memory_efficient_attention_xformers: return self.forward_memory_efficient_xformers(hidden_states, context, mask) @@ -748,13 +746,14 @@ class CrossAttention(nn.Module): out = self.to_out[0](out) return out + def translate_attention_names_from_diffusers( hidden_states: torch.FloatTensor, context: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None, # HF naming encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None + attention_mask: Optional[torch.FloatTensor] = None, ): # translate from hugging face diffusers context = context if context is not None else encoder_hidden_states @@ -764,6 +763,7 @@ def translate_attention_names_from_diffusers( return hidden_states, context, mask + # feedforward class GEGLU(nn.Module): r""" @@ -1015,9 +1015,11 @@ class CrossAttnDownBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False )[0] else: hidden_states = resnet(hidden_states, temb) @@ -1098,10 +1100,12 @@ class UNetMidBlock2DCrossAttn(nn.Module): if attn is not None: hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False )[0] - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: if attn is not None: hidden_states = attn(hidden_states, encoder_hidden_states).sample @@ -1201,7 +1205,9 @@ class UpBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: hidden_states = resnet(hidden_states, temb) @@ -1296,9 +1302,11 @@ class CrossAttnUpBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False )[0] else: hidden_states = resnet(hidden_states, temb) diff --git a/library/train_util.py b/library/train_util.py index 61e42108..caafcc28 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6008,7 +6008,6 @@ def get_noise_noisy_latents_and_timesteps( b_size = latents.shape[0] min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep @@ -6281,7 +6280,7 @@ def sample_images_common( vae, tokenizer, text_encoder, - unet, + unet_wrapped, prompt_replacement=None, controlnet=None, ): @@ -6316,7 +6315,7 @@ def sample_images_common( vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device # unwrap unet and text_encoder(s) - unet = accelerator.unwrap_model(unet) + unet = accelerator.unwrap_model(unet_wrapped) if isinstance(text_encoder, (list, tuple)): text_encoder = [accelerator.unwrap_model(te) for te in text_encoder] else: @@ -6462,7 +6461,7 @@ def sample_image_inference( logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") - with accelerator.autocast(): + with accelerator.autocast(), torch.no_grad(): latents = pipeline( prompt=prompt, height=height, diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py new file mode 100644 index 00000000..1ba14563 --- /dev/null +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -0,0 +1,4 @@ +# dummy module for pytorch_lightning + +class ModelCheckpoint: + pass diff --git a/requirements.txt b/requirements.txt index 448af323..624978b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,28 +1,29 @@ -accelerate==0.33.0 -transformers==4.44.0 -diffusers[torch]==0.25.0 -ftfy==6.1.1 +accelerate==1.6.0 +transformers==4.54.1 +diffusers[torch]==0.32.1 +ftfy==6.3.1 # albumentations==1.3.0 -opencv-python==4.8.1.78 +opencv-python==4.10.0.84 einops==0.7.0 -pytorch-lightning==1.9.0 -bitsandbytes==0.44.0 -lion-pytorch==0.0.6 +# pytorch-lightning==1.9.0 +bitsandbytes +lion-pytorch==0.2.3 schedulefree==1.4 pytorch-optimizer==3.7.0 -prodigy-plus-schedule-free==1.9.0 +prodigy-plus-schedule-free==1.9.2 prodigyopt==1.1.2 tensorboard -safetensors==0.4.4 +safetensors==0.4.5 # gradio==3.16.2 -altair==4.2.2 -easygui==0.98.3 +# altair==4.2.2 +# easygui==0.98.3 toml==0.10.2 -voluptuous==0.13.1 -huggingface-hub==0.24.5 +voluptuous==0.15.2 +huggingface-hub==0.34.3 # for Image utils imagesize==1.4.1 -numpy<=2.0 +numpy +# <=2.0 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 @@ -41,8 +42,8 @@ numpy<=2.0 # open clip for SDXL # open-clip-torch==2.20.0 # For logging -rich==13.7.0 +rich==14.1.0 # for T5XXL tokenizer (SD3/FLUX) -sentencepiece==0.2.0 +sentencepiece==0.2.1 # for kohya_ss library -e . diff --git a/sdxl_train_network.py b/sdxl_train_network.py index d56c76b0..5c5bcd63 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -23,7 +23,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): sdxl_train_util.verify_sdxl_training_args(args) if args.cache_text_encoder_outputs: diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 98200760..be538cdd 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -20,7 +20,6 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine self.is_sdxl = True def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): - super().assert_extra_args(args, train_dataset_group, val_dataset_group) sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) train_dataset_group.verify_bucket_reso_steps(32) diff --git a/train_control_net.py b/train_control_net.py index 97cd1ebb..c12693ba 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -12,7 +12,7 @@ import toml from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base, strategy_sd from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -73,7 +73,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer = tokenize_strategy.tokenizer + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -100,7 +107,7 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) @@ -243,12 +250,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -267,6 +269,7 @@ def train(args): # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + train_dataset_group.set_current_strategies() n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -451,7 +454,7 @@ def train(args): latents = latents * 0.18215 b_size = latents.shape[0] - input_ids = batch["input_ids"].to(accelerator.device) + input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) # Sample noise that we'll add to the latents diff --git a/train_network.py b/train_network.py index e055f5d8..3dedb574 100644 --- a/train_network.py +++ b/train_network.py @@ -414,13 +414,12 @@ class NetworkTrainer: if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions']) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), @@ -1340,7 +1339,7 @@ class NetworkTrainer: ) NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1] validation_total_steps = validation_steps * len(validation_timesteps) original_args_min_timestep = args.min_timestep