From 6edbe00547bac6c2efb2d6952eb910851662cdf2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 16 Aug 2025 20:07:03 +0900 Subject: [PATCH 1/8] feat: update libraries, remove warnings --- library/model_util.py | 5 +- library/original_unet.py | 42 ++++++----- library/train_util.py | 72 ++++++++++++------- pytorch_lightning/__init__.py | 0 pytorch_lightning/callbacks/__init__.py | 0 .../callbacks/model_checkpoint.py | 4 ++ requirements.txt | 35 ++++----- sdxl_train_network.py | 7 +- train_network.py | 5 +- 9 files changed, 107 insertions(+), 63 deletions(-) create mode 100644 pytorch_lightning/__init__.py create mode 100644 pytorch_lightning/callbacks/__init__.py create mode 100644 pytorch_lightning/callbacks/model_checkpoint.py diff --git a/library/model_util.py b/library/model_util.py index 9918c7b2..bcaa1145 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -6,6 +6,7 @@ import os import torch from library.device_utils import init_ipex + init_ipex() import diffusers @@ -14,8 +15,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , from safetensors.torch import load_file, save_file from library.original_unet import UNet2DConditionModel from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # DiffUsers版StableDiffusionのモデルパラメータ @@ -974,7 +977,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): checkpoint = None state_dict = load_file(ckpt_path) # , device) # may causes error else: - checkpoint = torch.load(ckpt_path, map_location=device) + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False) if "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] else: diff --git a/library/original_unet.py b/library/original_unet.py index e944ff22..aa9dc233 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -114,8 +114,10 @@ from torch import nn from torch.nn import functional as F from einops import rearrange from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) @@ -530,7 +532,9 @@ class DownBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: hidden_states = resnet(hidden_states, temb) @@ -626,15 +630,9 @@ class CrossAttention(nn.Module): hidden_states, encoder_hidden_states, attention_mask, - ) = translate_attention_names_from_diffusers( - hidden_states=hidden_states, context=context, mask=mask, **kwargs - ) + ) = translate_attention_names_from_diffusers(hidden_states=hidden_states, context=context, mask=mask, **kwargs) return self.processor( - attn=self, - hidden_states=hidden_states, - encoder_hidden_states=context, - attention_mask=mask, - **kwargs + attn=self, hidden_states=hidden_states, encoder_hidden_states=context, attention_mask=mask, **kwargs ) if self.use_memory_efficient_attention_xformers: return self.forward_memory_efficient_xformers(hidden_states, context, mask) @@ -748,13 +746,14 @@ class CrossAttention(nn.Module): out = self.to_out[0](out) return out + def translate_attention_names_from_diffusers( hidden_states: torch.FloatTensor, context: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None, # HF naming encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None + attention_mask: Optional[torch.FloatTensor] = None, ): # translate from hugging face diffusers context = context if context is not None else encoder_hidden_states @@ -764,6 +763,7 @@ def translate_attention_names_from_diffusers( return hidden_states, context, mask + # feedforward class GEGLU(nn.Module): r""" @@ -1015,9 +1015,11 @@ class CrossAttnDownBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False )[0] else: hidden_states = resnet(hidden_states, temb) @@ -1098,10 +1100,12 @@ class UNetMidBlock2DCrossAttn(nn.Module): if attn is not None: hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False )[0] - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: if attn is not None: hidden_states = attn(hidden_states, encoder_hidden_states).sample @@ -1201,7 +1205,9 @@ class UpBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) else: hidden_states = resnet(hidden_states, temb) @@ -1296,9 +1302,11 @@ class CrossAttnUpBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False )[0] else: hidden_states = resnet(hidden_states, temb) diff --git a/library/train_util.py b/library/train_util.py index 39518395..b432d0b6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -683,7 +683,7 @@ class BaseDataset(torch.utils.data.Dataset): resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, - resize_interpolation: Optional[str] = None + resize_interpolation: Optional[str] = None, ) -> None: super().__init__() @@ -719,7 +719,9 @@ class BaseDataset(torch.utils.data.Dataset): self.image_transforms = IMAGE_TRANSFORMS if resize_interpolation is not None: - assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation" + assert validate_interpolation_fn( + resize_interpolation + ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' self.resize_interpolation = resize_interpolation self.image_data: Dict[str, ImageInfo] = {} @@ -1613,7 +1615,11 @@ class BaseDataset(torch.utils.data.Dataset): if self.enable_bucket: img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation + subset.random_crop, + img, + image_info.bucket_reso, + image_info.resized_size, + resize_interpolation=image_info.resize_interpolation, ) else: if face_cx > 0: # 顔位置情報あり @@ -2101,7 +2107,9 @@ class DreamBoothDataset(BaseDataset): for img_path, caption, size in zip(img_paths, captions, sizes): info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) - info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation + info.resize_interpolation = ( + subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation + ) if size is not None: info.image_size = size if subset.is_reg: @@ -2385,7 +2393,7 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], resize_interpolation: Optional[str] = None, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2448,7 +2456,7 @@ class ControlNetDataset(BaseDataset): self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed + self.validation_seed = validation_seed self.resize_interpolation = resize_interpolation # assert all conditioning data exists @@ -2538,7 +2546,14 @@ class ControlNetDataset(BaseDataset): cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation) + cond_img = resize_image( + cond_img, + original_size_hw[1], + original_size_hw[0], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + ) # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -2552,7 +2567,14 @@ class ControlNetDataset(BaseDataset): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation) + cond_img = resize_image( + cond_img, + cond_img.shape[0], + cond_img.shape[1], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + ) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -3000,7 +3022,9 @@ def load_images_and_masks_for_caching( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) original_sizes.append(original_size) crop_ltrbs.append(crop_ltrb) @@ -3041,7 +3065,9 @@ def cache_batch_latents( for info in image_infos: image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb @@ -3482,9 +3508,9 @@ def get_sai_model_spec( textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, - flux: str = None, # "dev", "schnell" or "chroma" + flux: str = None, # "dev", "schnell" or "chroma" lumina: str = None, - optional_metadata: dict[str, str] | None = None + optional_metadata: dict[str, str] | None = None, ): timestamp = time.time() @@ -3513,7 +3539,7 @@ def get_sai_model_spec( # Extract metadata_* fields from args and merge with optional_metadata extracted_metadata = {} - + # Extract all metadata_* attributes from args for attr_name in dir(args): if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"): @@ -3523,7 +3549,7 @@ def get_sai_model_spec( field_name = attr_name[9:] # len("metadata_") = 9 if field_name not in ["title", "author", "description", "license", "tags"]: extracted_metadata[field_name] = value - + # Merge extracted metadata with provided optional_metadata all_optional_metadata = {**extracted_metadata} if optional_metadata: @@ -3546,7 +3572,7 @@ def get_sai_model_spec( tags=args.metadata_tags, timesteps=timesteps, clip_skip=args.clip_skip, # None or int - model_config=model_config, + model_config=model_config, optional_metadata=all_optional_metadata if all_optional_metadata else None, ) return metadata @@ -3562,7 +3588,7 @@ def get_sai_model_spec_dataclass( sd3: str = None, flux: str = None, lumina: str = None, - optional_metadata: dict[str, str] | None = None + optional_metadata: dict[str, str] | None = None, ) -> sai_model_spec.ModelSpecMetadata: """ Get ModelSpec metadata as a dataclass - preferred for new code. @@ -5558,11 +5584,12 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio def patch_accelerator_for_fp16_training(accelerator): - + from accelerate import DistributedType + if accelerator.distributed_type == DistributedType.DEEPSPEED: return - + org_unscale_grads = accelerator.scaler._unscale_grads_ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): @@ -6054,7 +6081,6 @@ def get_noise_noisy_latents_and_timesteps( b_size = latents.shape[0] min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) # Add noise to the latents according to the noise magnitude at each timestep @@ -6279,7 +6305,6 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["renorm_cfg"] = float(m.group(1)) continue - except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(ex) @@ -6328,7 +6353,7 @@ def sample_images_common( vae, tokenizer, text_encoder, - unet, + unet_wrapped, prompt_replacement=None, controlnet=None, ): @@ -6363,7 +6388,7 @@ def sample_images_common( vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device # unwrap unet and text_encoder(s) - unet = accelerator.unwrap_model(unet) + unet = accelerator.unwrap_model(unet_wrapped) if isinstance(text_encoder, (list, tuple)): text_encoder = [accelerator.unwrap_model(te) for te in text_encoder] else: @@ -6509,7 +6534,7 @@ def sample_image_inference( logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") - with accelerator.autocast(): + with accelerator.autocast(), torch.no_grad(): latents = pipeline( prompt=prompt, height=height, @@ -6647,4 +6672,3 @@ class LossRecorder: if losses == 0: return 0 return self.loss_total / losses - diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py new file mode 100644 index 00000000..1ba14563 --- /dev/null +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -0,0 +1,4 @@ +# dummy module for pytorch_lightning + +class ModelCheckpoint: + pass diff --git a/requirements.txt b/requirements.txt index 448af323..7c7060c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,28 +1,29 @@ -accelerate==0.33.0 -transformers==4.44.0 -diffusers[torch]==0.25.0 -ftfy==6.1.1 +accelerate==1.6.0 +transformers==4.54.1 +diffusers[torch]==0.32.1 +ftfy==6.3.1 # albumentations==1.3.0 -opencv-python==4.8.1.78 +opencv-python==4.10.0.84 einops==0.7.0 -pytorch-lightning==1.9.0 -bitsandbytes==0.44.0 -lion-pytorch==0.0.6 +# pytorch-lightning==1.9.0 +bitsandbytes==0.45.4 +lion-pytorch==0.2.3 schedulefree==1.4 pytorch-optimizer==3.7.0 -prodigy-plus-schedule-free==1.9.0 +prodigy-plus-schedule-free==1.9.2 prodigyopt==1.1.2 tensorboard -safetensors==0.4.4 +safetensors==0.4.5 # gradio==3.16.2 -altair==4.2.2 -easygui==0.98.3 +# altair==4.2.2 +# easygui==0.98.3 toml==0.10.2 -voluptuous==0.13.1 -huggingface-hub==0.24.5 +voluptuous==0.15.2 +huggingface-hub==0.34.3 # for Image utils imagesize==1.4.1 -numpy<=2.0 +numpy +# <=2.0 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 @@ -41,8 +42,8 @@ numpy<=2.0 # open clip for SDXL # open-clip-torch==2.20.0 # For logging -rich==13.7.0 +rich==14.1.0 # for T5XXL tokenizer (SD3/FLUX) -sentencepiece==0.2.0 +sentencepiece==0.2.1 # for kohya_ss library -e . diff --git a/sdxl_train_network.py b/sdxl_train_network.py index d56c76b0..5c5bcd63 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -23,7 +23,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + def assert_extra_args( + self, + args, + train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], + val_dataset_group: Optional[train_util.DatasetGroup], + ): sdxl_train_util.verify_sdxl_training_args(args) if args.cache_text_encoder_outputs: diff --git a/train_network.py b/train_network.py index e055f5d8..3dedb574 100644 --- a/train_network.py +++ b/train_network.py @@ -414,13 +414,12 @@ class NetworkTrainer: if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions']) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), @@ -1340,7 +1339,7 @@ class NetworkTrainer: ) NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1] validation_total_steps = validation_steps * len(validation_timesteps) original_args_min_timestep = args.min_timestep From 6f24bce7ccacdf0c13614fe84413c2446adbf35c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 16 Aug 2025 22:03:31 +0900 Subject: [PATCH 2/8] fix: remove unnecessary super call in assert_extra_args method --- sdxl_train_textual_inversion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 98200760..be538cdd 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -20,7 +20,6 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine self.is_sdxl = True def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): - super().assert_extra_args(args, train_dataset_group, val_dataset_group) sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) train_dataset_group.verify_bucket_reso_steps(32) From f61c442f0b7d4bc30f3d3eb3e169c13f424107ad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 16 Aug 2025 22:03:52 +0900 Subject: [PATCH 3/8] fix: use strategy for tokenizer and latent caching --- train_control_net.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/train_control_net.py b/train_control_net.py index 97cd1ebb..c12693ba 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -12,7 +12,7 @@ import toml from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base, strategy_sd from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -73,7 +73,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer = tokenize_strategy.tokenizer + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -100,7 +107,7 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) @@ -243,12 +250,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -267,6 +269,7 @@ def train(args): # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + train_dataset_group.set_current_strategies() n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -451,7 +454,7 @@ def train(args): latents = latents * 0.18215 b_size = latents.shape[0] - input_ids = batch["input_ids"].to(accelerator.device) + input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) # Sample noise that we'll add to the latents From acba279b0b6a62ed9266f1011bbbae13d098d7fe Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 24 Aug 2025 17:03:18 +0900 Subject: [PATCH 4/8] fix: update PyTorch version in workflow matrix --- .github/workflows/tests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9e037e53..88b0b177 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,7 +23,8 @@ jobs: os: [ubuntu-latest] python-version: ["3.10"] # Python versions to test pytorch-version: ["2.4.0"] # PyTorch versions to test - + pytorch-version: ["2.6.0"] # PyTorch versions to test + steps: - uses: actions/checkout@v4 with: From f7acd2f7a3319487feb5277934f5cc998e46b231 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 24 Aug 2025 17:16:17 +0900 Subject: [PATCH 5/8] fix: consolidate PyTorch versions in workflow matrix --- .github/workflows/tests.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 88b0b177..d35fe392 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,9 +22,8 @@ jobs: matrix: os: [ubuntu-latest] python-version: ["3.10"] # Python versions to test - pytorch-version: ["2.4.0"] # PyTorch versions to test - pytorch-version: ["2.6.0"] # PyTorch versions to test - + pytorch-version: ["2.4.0", "2.6.0"] # PyTorch versions to test + steps: - uses: actions/checkout@v4 with: From ac72cf88a76d6165e70c96c6aa76282977bd378c Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 28 Aug 2025 08:35:40 +0900 Subject: [PATCH 6/8] feat: remove bitsandbytes version specification in requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7c7060c7..624978b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ ftfy==6.3.1 opencv-python==4.10.0.84 einops==0.7.0 # pytorch-lightning==1.9.0 -bitsandbytes==0.45.4 +bitsandbytes lion-pytorch==0.2.3 schedulefree==1.4 pytorch-optimizer==3.7.0 From c52c45cd7ab6b8e92b8967e1d46965a67128ee6a Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 28 Aug 2025 08:36:09 +0900 Subject: [PATCH 7/8] doc: update for PyTorch and libraries versions --- README.md | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index be0ae406..7dc9a4b6 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,29 @@ This repository contains training, generation and utility scripts for Stable Dif This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training. -__Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchvision==0.19.0` with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe. `requirements.txt` is also updated, so please update the requirements.__ +__Please update PyTorch to 2.6.0 or later. We have tested with `torch==2.6.0` and `torchvision==0.21.0` with CUDA 12.4. `requirements.txt` is also updated, so please update the requirements.__ The command to install PyTorch is as follows: -`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +`pip3 install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu124` -If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`. +For RTX 50 series GPUs, PyTorch 2.8.0 with CUDA 12.8/9 should be used. `requirements.txt` will work with this version. + +If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed` (appropriate version is not confirmed yet). - [FLUX.1 training](#flux1-training) - [SD3 training](#sd3-training) ### Recent Updates +Aug 28, 2025: +- In order to support the latest GPUs and features, we have updated the **PyTorch and library versions**. There are many changes, so please let us know if you encounter any issues. +- The PyTorch version used for testing has been updated to 2.6.0. We have confirmed that it works with PyTorch 2.6.0 and later. +- The `requirements.txt` has been updated, so please update your dependencies. + - You can update the dependencies with `pip install -r requirements.txt`. + - The version specification for `bitsandbytes` has been removed. If you encounter errors on RTX 50 series GPUs, please update it with `pip install -U bitsandbytes`. +- We have modified each script to minimize warnings as much as possible. + - The modified scripts will work in the old environment (library versions), but please update them when convenient. + Jul 30, 2025: - **Breaking Change**: For FLUX.1 and Chroma training, the CFG (Classifier-Free Guidance, using negative prompts) scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details. From 5a5138d0ab71d1d640851e6ffa8eb2ad2f2e2b60 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 28 Aug 2025 08:38:14 +0900 Subject: [PATCH 8/8] doc: add PR reference for PyTorch and library versions update --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7dc9a4b6..5953cef5 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed ### Recent Updates Aug 28, 2025: -- In order to support the latest GPUs and features, we have updated the **PyTorch and library versions**. There are many changes, so please let us know if you encounter any issues. +- In order to support the latest GPUs and features, we have updated the **PyTorch and library versions**. PR [#2178](https://github.com/kohya-ss/sd-scripts/pull/2178) There are many changes, so please let us know if you encounter any issues. - The PyTorch version used for testing has been updated to 2.6.0. We have confirmed that it works with PyTorch 2.6.0 and later. - The `requirements.txt` has been updated, so please update your dependencies. - You can update the dependencies with `pip install -r requirements.txt`.