diff --git a/library/model_util.py b/library/model_util.py
index 9918c7b2..bcaa1145 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -6,6 +6,7 @@ import os
 import torch
 
 from library.device_utils import init_ipex
+
 init_ipex()
 
 import diffusers
@@ -14,8 +15,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # ,
 from safetensors.torch import load_file, save_file
 from library.original_unet import UNet2DConditionModel
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
 
 # DiffUsers版StableDiffusionのモデルパラメータ
@@ -974,7 +977,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
         checkpoint = None
         state_dict = load_file(ckpt_path)  # , device) # may causes error
     else:
-        checkpoint = torch.load(ckpt_path, map_location=device)
+        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
         if "state_dict" in checkpoint:
             state_dict = checkpoint["state_dict"]
         else:
diff --git a/library/original_unet.py b/library/original_unet.py
index e944ff22..aa9dc233 100644
--- a/library/original_unet.py
+++ b/library/original_unet.py
@@ -114,8 +114,10 @@ from torch import nn
 from torch.nn import functional as F
 from einops import rearrange
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
 
 BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
@@ -530,7 +532,9 @@ class DownBlock2D(nn.Module):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -626,15 +630,9 @@ class CrossAttention(nn.Module):
                 hidden_states,
                 encoder_hidden_states,
                 attention_mask,
-            ) = translate_attention_names_from_diffusers(
-                hidden_states=hidden_states, context=context, mask=mask, **kwargs
-            )
+            ) = translate_attention_names_from_diffusers(hidden_states=hidden_states, context=context, mask=mask, **kwargs)
             return self.processor(
-                attn=self,
-                hidden_states=hidden_states,
-                encoder_hidden_states=context,
-                attention_mask=mask,
-                **kwargs
+                attn=self, hidden_states=hidden_states, encoder_hidden_states=context, attention_mask=mask, **kwargs
             )
         if self.use_memory_efficient_attention_xformers:
             return self.forward_memory_efficient_xformers(hidden_states, context, mask)
@@ -748,13 +746,14 @@ class CrossAttention(nn.Module):
         out = self.to_out[0](out)
         return out
 
+
 def translate_attention_names_from_diffusers(
     hidden_states: torch.FloatTensor,
     context: Optional[torch.FloatTensor] = None,
     mask: Optional[torch.FloatTensor] = None,
     # HF naming
     encoder_hidden_states: Optional[torch.FloatTensor] = None,
-    attention_mask: Optional[torch.FloatTensor] = None
+    attention_mask: Optional[torch.FloatTensor] = None,
 ):
     # translate from hugging face diffusers
     context = context if context is not None else encoder_hidden_states
@@ -764,6 +763,7 @@ def translate_attention_names_from_diffusers(
 
     return hidden_states, context, mask
 
+
 # feedforward
 class GEGLU(nn.Module):
     r"""
@@ -1015,9 +1015,11 @@ class CrossAttnDownBlock2D(nn.Module):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+                    create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1098,10 +1100,12 @@ class UNetMidBlock2DCrossAttn(nn.Module):
 
                 if attn is not None:
                     hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+                        create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
                     )[0]
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                )
             else:
                 if attn is not None:
                     hidden_states = attn(hidden_states, encoder_hidden_states).sample
@@ -1201,7 +1205,9 @@ class UpBlock2D(nn.Module):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1296,9 +1302,11 @@ class CrossAttnUpBlock2D(nn.Module):
 
                     return custom_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
+                    create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
diff --git a/library/train_util.py b/library/train_util.py
index 39518395..b432d0b6 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -683,7 +683,7 @@ class BaseDataset(torch.utils.data.Dataset):
         resolution: Optional[Tuple[int, int]],
         network_multiplier: float,
         debug_dataset: bool,
-        resize_interpolation: Optional[str] = None
+        resize_interpolation: Optional[str] = None,
     ) -> None:
         super().__init__()
 
@@ -719,7 +719,9 @@ class BaseDataset(torch.utils.data.Dataset):
         self.image_transforms = IMAGE_TRANSFORMS
 
         if resize_interpolation is not None:
-            assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
+            assert validate_interpolation_fn(
+                resize_interpolation
+            ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
         self.resize_interpolation = resize_interpolation
 
         self.image_data: Dict[str, ImageInfo] = {}
@@ -1613,7 +1615,11 @@ class BaseDataset(torch.utils.data.Dataset):
 
         if self.enable_bucket:
             img, original_size, crop_ltrb = trim_and_resize_if_required(
-                subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
+                subset.random_crop,
+                img,
+                image_info.bucket_reso,
+                image_info.resized_size,
+                resize_interpolation=image_info.resize_interpolation,
             )
         else:
             if face_cx > 0:  # 顔位置情報あり
@@ -2101,7 +2107,9 @@ class DreamBoothDataset(BaseDataset):
 
         for img_path, caption, size in zip(img_paths, captions, sizes):
             info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
-            info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
+            info.resize_interpolation = (
+                subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
+            )
             if size is not None:
                 info.image_size = size
             if subset.is_reg:
@@ -2385,7 +2393,7 @@ class ControlNetDataset(BaseDataset):
         bucket_no_upscale: bool,
         debug_dataset: bool,
         validation_split: float,
-        validation_seed: Optional[int],    
+        validation_seed: Optional[int],
         resize_interpolation: Optional[str] = None,
     ) -> None:
         super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
@@ -2448,7 +2456,7 @@ class ControlNetDataset(BaseDataset):
         self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
         self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
         self.validation_split = validation_split
-        self.validation_seed = validation_seed    
+        self.validation_seed = validation_seed
         self.resize_interpolation = resize_interpolation
 
         # assert all conditioning data exists
@@ -2538,7 +2546,14 @@ class ControlNetDataset(BaseDataset):
                 cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
             ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
 
-            cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
+            cond_img = resize_image(
+                cond_img,
+                original_size_hw[1],
+                original_size_hw[0],
+                target_size_hw[1],
+                target_size_hw[0],
+                self.resize_interpolation,
+            )
 
             # TODO support random crop
             # 現在サポートしているcropはrandomではなく中央のみ
@@ -2552,7 +2567,14 @@ class ControlNetDataset(BaseDataset):
             #     ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
             # resize to target
             if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
-                cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
+                cond_img = resize_image(
+                    cond_img,
+                    cond_img.shape[0],
+                    cond_img.shape[1],
+                    target_size_hw[1],
+                    target_size_hw[0],
+                    self.resize_interpolation,
+                )
 
         if flipped:
             cond_img = cond_img[:, ::-1, :].copy()  # copy to avoid negative stride
@@ -3000,7 +3022,9 @@ def load_images_and_masks_for_caching(
     for info in image_infos:
         image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
         # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
-        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
+        image, original_size, crop_ltrb = trim_and_resize_if_required(
+            random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
+        )
 
         original_sizes.append(original_size)
         crop_ltrbs.append(crop_ltrb)
@@ -3041,7 +3065,9 @@ def cache_batch_latents(
     for info in image_infos:
         image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
         # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
-        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
+        image, original_size, crop_ltrb = trim_and_resize_if_required(
+            random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
+        )
 
         info.latents_original_size = original_size
         info.latents_crop_ltrb = crop_ltrb
@@ -3482,9 +3508,9 @@ def get_sai_model_spec(
     textual_inversion: bool,
     is_stable_diffusion_ckpt: Optional[bool] = None,  # None for TI and LoRA
     sd3: str = None,
-    flux: str = None, # "dev", "schnell" or "chroma"
+    flux: str = None,  # "dev", "schnell" or "chroma"
     lumina: str = None,
-    optional_metadata: dict[str, str] | None = None
+    optional_metadata: dict[str, str] | None = None,
 ):
     timestamp = time.time()
 
@@ -3513,7 +3539,7 @@ def get_sai_model_spec(
 
     # Extract metadata_* fields from args and merge with optional_metadata
     extracted_metadata = {}
-    
+
     # Extract all metadata_* attributes from args
     for attr_name in dir(args):
         if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
@@ -3523,7 +3549,7 @@ def get_sai_model_spec(
                 field_name = attr_name[9:]  # len("metadata_") = 9
                 if field_name not in ["title", "author", "description", "license", "tags"]:
                     extracted_metadata[field_name] = value
-    
+
     # Merge extracted metadata with provided optional_metadata
     all_optional_metadata = {**extracted_metadata}
     if optional_metadata:
@@ -3546,7 +3572,7 @@ def get_sai_model_spec(
         tags=args.metadata_tags,
         timesteps=timesteps,
         clip_skip=args.clip_skip,  # None or int
-        model_config=model_config,    
+        model_config=model_config,
         optional_metadata=all_optional_metadata if all_optional_metadata else None,
     )
     return metadata
@@ -3562,7 +3588,7 @@ def get_sai_model_spec_dataclass(
     sd3: str = None,
     flux: str = None,
     lumina: str = None,
-    optional_metadata: dict[str, str] | None = None
+    optional_metadata: dict[str, str] | None = None,
 ) -> sai_model_spec.ModelSpecMetadata:
     """
     Get ModelSpec metadata as a dataclass - preferred for new code.
@@ -5558,11 +5584,12 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
 
 
 def patch_accelerator_for_fp16_training(accelerator):
-    
+
     from accelerate import DistributedType
+
     if accelerator.distributed_type == DistributedType.DEEPSPEED:
         return
-    
+
     org_unscale_grads = accelerator.scaler._unscale_grads_
 
     def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
@@ -6054,7 +6081,6 @@ def get_noise_noisy_latents_and_timesteps(
     b_size = latents.shape[0]
     min_timestep = 0 if args.min_timestep is None else args.min_timestep
     max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
-
     timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
 
     # Add noise to the latents according to the noise magnitude at each timestep
@@ -6279,7 +6305,6 @@ def line_to_prompt_dict(line: str) -> dict:
                     prompt_dict["renorm_cfg"] = float(m.group(1))
                     continue
 
-
         except ValueError as ex:
             logger.error(f"Exception in parsing / 解析エラー: {parg}")
             logger.error(ex)
@@ -6328,7 +6353,7 @@ def sample_images_common(
     vae,
     tokenizer,
     text_encoder,
-    unet,
+    unet_wrapped,
     prompt_replacement=None,
     controlnet=None,
 ):
@@ -6363,7 +6388,7 @@ def sample_images_common(
         vae.to(distributed_state.device)  # distributed_state.device is same as accelerator.device
 
     # unwrap unet and text_encoder(s)
-    unet = accelerator.unwrap_model(unet)
+    unet = accelerator.unwrap_model(unet_wrapped)
     if isinstance(text_encoder, (list, tuple)):
         text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
     else:
@@ -6509,7 +6534,7 @@ def sample_image_inference(
     logger.info(f"sample_sampler: {sampler_name}")
     if seed is not None:
         logger.info(f"seed: {seed}")
-    with accelerator.autocast():
+    with accelerator.autocast(), torch.no_grad():
         latents = pipeline(
             prompt=prompt,
             height=height,
@@ -6647,4 +6672,3 @@ class LossRecorder:
         if losses == 0:
             return 0
         return self.loss_total / losses
-
diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
new file mode 100644
index 00000000..1ba14563
--- /dev/null
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -0,0 +1,4 @@
+# dummy module for pytorch_lightning
+
+class ModelCheckpoint:
+    pass
diff --git a/requirements.txt b/requirements.txt
index 448af323..7c7060c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,28 +1,29 @@
-accelerate==0.33.0
-transformers==4.44.0
-diffusers[torch]==0.25.0
-ftfy==6.1.1
+accelerate==1.6.0
+transformers==4.54.1
+diffusers[torch]==0.32.1
+ftfy==6.3.1
 # albumentations==1.3.0
-opencv-python==4.8.1.78
+opencv-python==4.10.0.84
 einops==0.7.0
-pytorch-lightning==1.9.0
-bitsandbytes==0.44.0
-lion-pytorch==0.0.6
+# pytorch-lightning==1.9.0
+bitsandbytes==0.45.4
+lion-pytorch==0.2.3
 schedulefree==1.4
 pytorch-optimizer==3.7.0
-prodigy-plus-schedule-free==1.9.0
+prodigy-plus-schedule-free==1.9.2
 prodigyopt==1.1.2
 tensorboard
-safetensors==0.4.4
+safetensors==0.4.5
 # gradio==3.16.2
-altair==4.2.2
-easygui==0.98.3
+# altair==4.2.2
+# easygui==0.98.3
 toml==0.10.2
-voluptuous==0.13.1
-huggingface-hub==0.24.5
+voluptuous==0.15.2
+huggingface-hub==0.34.3
 # for Image utils
 imagesize==1.4.1
-numpy<=2.0
+numpy
+# <=2.0
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12
@@ -41,8 +42,8 @@ numpy<=2.0
 # open clip for SDXL
 # open-clip-torch==2.20.0
 # For logging
-rich==13.7.0
+rich==14.1.0
 # for T5XXL tokenizer (SD3/FLUX)
-sentencepiece==0.2.0
+sentencepiece==0.2.1
 # for kohya_ss library
 -e .
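
Note on the requirements change above: pytorch-lightning is commented out, and the new top-level pytorch_lightning package added by this patch is a dummy whose only job is to satisfy imports (plausibly so that legacy .ckpt pickles referencing Lightning's ModelCheckpoint class still unpickle -- that rationale is an inference, not stated in the patch). A quick sanity check, assuming it runs from the repo root so the stub package is importable:

    # Hypothetical check, not part of the patch: the stub resolves and instantiates.
    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

    ckpt = ModelCheckpoint()  # inert placeholder with no Lightning behavior
    print(type(ckpt).__module__)  # pytorch_lightning.callbacks.model_checkpoint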
diff --git a/sdxl_train_network.py b/sdxl_train_network.py
index d56c76b0..5c5bcd63 100644
--- a/sdxl_train_network.py
+++ b/sdxl_train_network.py
@@ -23,7 +23,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
         self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
         self.is_sdxl = True
 
-    def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
+    def assert_extra_args(
+        self,
+        args,
+        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
+        val_dataset_group: Optional[train_util.DatasetGroup],
+    ):
         sdxl_train_util.verify_sdxl_training_args(args)
 
         if args.cache_text_encoder_outputs:
diff --git a/train_network.py b/train_network.py
index e055f5d8..3dedb574 100644
--- a/train_network.py
+++ b/train_network.py
@@ -414,13 +414,12 @@ class NetworkTrainer:
 
                 if text_encoder_outputs_list is not None:
                     text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
-
                 if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                     # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
                     with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
                         # Get the text embedding for conditioning
                         if args.weighted_captions:
-                            input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
+                            input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
                             encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
                                 tokenize_strategy,
                                 self.get_models_for_text_encoding(args, accelerator, text_encoders),
@@ -1340,7 +1339,7 @@ class NetworkTrainer:
                 )
                 NUM_VALIDATION_TIMESTEPS = 4  # 200, 400, 600, 800 TODO make this configurable
                 min_timestep = 0 if args.min_timestep is None else args.min_timestep
-                max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
+                max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
                 validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
                 validation_total_steps = validation_steps * len(validation_timesteps)
                 original_args_min_timestep = args.min_timestep
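
Note on the use_reentrant changes in library/original_unet.py: every gradient-checkpointing call now passes use_reentrant=False explicitly. Recent PyTorch versions warn when the argument is omitted and recommend the non-reentrant variant, so this pins the behavior and silences the warning. A standalone sketch of the pattern (hypothetical layer and shapes, not code from this repo):

    import torch
    import torch.utils.checkpoint

    layer = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    # State use_reentrant explicitly instead of relying on the legacy default (True).
    y = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, use_reentrant=False)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 8])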
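Note on the two runtime fixes in library/model_util.py and library/train_util.py: PyTorch 2.6+ defaults torch.load to weights_only=True, which refuses checkpoints containing arbitrary pickled objects, so loading full .ckpt files needs the explicit weights_only=False opt-in (appropriate only for trusted files). Separately, sample_image_inference now runs the pipeline under torch.no_grad() so sampling during training does not build an autograd graph. A small self-contained sketch (the file path is illustrative):

    import torch

    # A legacy-style checkpoint is a pickled dict, not just raw tensors.
    torch.save({"state_dict": {"w": torch.zeros(2)}}, "/tmp/legacy.ckpt")

    # Explicit opt-in; under the new weights_only=True default, checkpoints that
    # carry non-tensor pickled objects can fail to load.
    checkpoint = torch.load("/tmp/legacy.ckpt", map_location="cpu", weights_only=False)

    # Inference-time work should not record gradients.
    with torch.no_grad():
        out = checkpoint["state_dict"]["w"] + 1
    print(out.requires_grad)  # False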