feat: update libraries, remove warnings

Kohya S
2025-08-16 20:07:03 +09:00
parent 18e62515c4
commit 6edbe00547
9 changed files with 107 additions and 63 deletions

View File

@@ -6,6 +6,7 @@ import os
import torch
from library.device_utils import init_ipex
init_ipex()
import diffusers
@@ -14,8 +15,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # ,
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# model parameters for the Diffusers version of Stable Diffusion
@@ -974,7 +977,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
checkpoint = None
state_dict = load_file(ckpt_path) # , device) # may cause an error
else:
checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
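
Note: PyTorch 2.6 changed the default of torch.load to weights_only=True, which refuses checkpoints that pickle arbitrary Python objects; passing weights_only=False restores the old behavior for trusted .ckpt files and silences the FutureWarning on earlier versions. A minimal sketch, not part of the commit:

import torch

def load_trusted_checkpoint(ckpt_path: str, device: str = "cpu"):
    try:
        # PyTorch >= 2.6 default: only plain tensors/containers may be unpickled
        return torch.load(ckpt_path, map_location=device, weights_only=True)
    except Exception:
        # legacy .ckpt files pickle arbitrary objects (e.g. Lightning callbacks);
        # opt out only for files you trust, since unpickling can execute code
        return torch.load(ckpt_path, map_location=device, weights_only=False)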

View File

@@ -114,8 +114,10 @@ from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
@@ -530,7 +532,9 @@ class DownBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = resnet(hidden_states, temb)
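
Note: every checkpoint call in this file now passes use_reentrant=False; torch.utils.checkpoint warns (and newer PyTorch requires) that the flag be given explicitly, and the non-reentrant variant is the recommended implementation. A standalone sketch, not part of the commit:

import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# activations inside `block` are recomputed during backward instead of stored;
# use_reentrant=False handles keyword arguments and grad-mode edge cases that
# the old reentrant implementation did not
y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()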
@@ -626,15 +630,9 @@ class CrossAttention(nn.Module):
hidden_states,
encoder_hidden_states,
attention_mask,
) = translate_attention_names_from_diffusers(
hidden_states=hidden_states, context=context, mask=mask, **kwargs
)
) = translate_attention_names_from_diffusers(hidden_states=hidden_states, context=context, mask=mask, **kwargs)
return self.processor(
attn=self,
hidden_states=hidden_states,
encoder_hidden_states=context,
attention_mask=mask,
**kwargs
attn=self, hidden_states=hidden_states, encoder_hidden_states=context, attention_mask=mask, **kwargs
)
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
@@ -748,13 +746,14 @@ class CrossAttention(nn.Module):
out = self.to_out[0](out)
return out
def translate_attention_names_from_diffusers(
hidden_states: torch.FloatTensor,
context: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
# HF naming
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None
attention_mask: Optional[torch.FloatTensor] = None,
):
# translate from hugging face diffusers
context = context if context is not None else encoder_hidden_states
@@ -764,6 +763,7 @@ def translate_attention_names_from_diffusers(
return hidden_states, context, mask
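
Note: this shim lets CrossAttention.forward accept either this repository's argument names (context/mask) or the diffusers names (encoder_hidden_states/attention_mask). A minimal standalone sketch; the mask line elided between the hunks above is assumed to mirror the context line:

def translate(hidden_states, context=None, mask=None,
              encoder_hidden_states=None, attention_mask=None):
    context = context if context is not None else encoder_hidden_states
    mask = mask if mask is not None else attention_mask
    return hidden_states, context, mask

# both spellings land in the same place
hs, ctx, m = translate("h", encoder_hidden_states="cond", attention_mask="m")
assert (ctx, m) == ("cond", "m")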
# feedforward
class GEGLU(nn.Module):
r"""
@@ -1015,9 +1015,11 @@ class CrossAttnDownBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
)[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -1098,10 +1100,12 @@ class UNetMidBlock2DCrossAttn(nn.Module):
if attn is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
if attn is not None:
hidden_states = attn(hidden_states, encoder_hidden_states).sample
@@ -1201,7 +1205,9 @@ class UpBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
hidden_states = resnet(hidden_states, temb)
@@ -1296,9 +1302,11 @@ class CrossAttnUpBlock2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
)[0]
else:
hidden_states = resnet(hidden_states, temb)

View File

@@ -683,7 +683,7 @@ class BaseDataset(torch.utils.data.Dataset):
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__()
@@ -719,7 +719,9 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_transforms = IMAGE_TRANSFORMS
if resize_interpolation is not None:
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
assert validate_interpolation_fn(
resize_interpolation
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
self.resize_interpolation = resize_interpolation
self.image_data: Dict[str, ImageInfo] = {}
@@ -1613,7 +1615,11 @@ class BaseDataset(torch.utils.data.Dataset):
if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
subset.random_crop,
img,
image_info.bucket_reso,
image_info.resized_size,
resize_interpolation=image_info.resize_interpolation,
)
else:
if face_cx > 0: # face position info is available
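
Note: trim_and_resize_if_required scales the image to its precomputed size for the bucket, then crops (center, or random when random_crop is set) to the exact bucket resolution, returning the original size and crop box for conditioning metadata. A hedged conceptual sketch, not the repository's implementation:

import random
import cv2
import numpy as np

def trim_and_resize(random_crop: bool, img: np.ndarray, bucket_reso, resized_size):
    original_size = (img.shape[1], img.shape[0])  # (w, h)
    if (img.shape[1], img.shape[0]) != tuple(resized_size):
        img = cv2.resize(img, tuple(resized_size), interpolation=cv2.INTER_AREA)
    bw, bh = bucket_reso
    h, w = img.shape[:2]
    # crop down to the bucket resolution, centered unless random_crop is set
    left = random.randint(0, w - bw) if random_crop else (w - bw) // 2
    top = random.randint(0, h - bh) if random_crop else (h - bh) // 2
    crop_ltrb = (left, top, left + bw, top + bh)
    return img[top : top + bh, left : left + bw], original_size, crop_ltrb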
@@ -2101,7 +2107,9 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -2538,7 +2546,14 @@ class ControlNetDataset(BaseDataset):
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
cond_img = resize_image(
cond_img,
original_size_hw[1],
original_size_hw[0],
target_size_hw[1],
target_size_hw[0],
self.resize_interpolation,
)
# TODO support random crop
# currently only center crop is supported, not random crop
@@ -2552,7 +2567,14 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
cond_img = resize_image(
cond_img,
cond_img.shape[0],
cond_img.shape[1],
target_size_hw[1],
target_size_hw[0],
self.resize_interpolation,
)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
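
Note: resize_image routes the configured interpolation name to the backend resize call. A hedged sketch of the kind of name-to-OpenCV-flag mapping involved; the real helper in library.utils may choose flags differently (e.g. depending on upscale vs. downscale):

import cv2
import numpy as np

_INTERPOLATIONS = {
    "lanczos": cv2.INTER_LANCZOS4,
    "bicubic": cv2.INTER_CUBIC,
    "bilinear": cv2.INTER_LINEAR,
    "nearest": cv2.INTER_NEAREST,
    "area": cv2.INTER_AREA,
}

def resize_to(img: np.ndarray, width: int, height: int, interpolation=None):
    flag = _INTERPOLATIONS.get(interpolation or "area", cv2.INTER_AREA)
    return cv2.resize(img, (width, height), interpolation=flag)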
@@ -3000,7 +3022,9 @@ def load_images_and_masks_for_caching(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO image metadata can be corrupted, so the bucket assigned from metadata may not match the actual image size; a check should be added
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
image, original_size, crop_ltrb = trim_and_resize_if_required(
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
)
original_sizes.append(original_size)
crop_ltrbs.append(crop_ltrb)
@@ -3041,7 +3065,9 @@ def cache_batch_latents(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO image metadata can be corrupted, so the bucket assigned from metadata may not match the actual image size; a check should be added
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
image, original_size, crop_ltrb = trim_and_resize_if_required(
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
)
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -3484,7 +3510,7 @@ def get_sai_model_spec(
sd3: str = None,
flux: str = None, # "dev", "schnell" or "chroma"
lumina: str = None,
optional_metadata: dict[str, str] | None = None
optional_metadata: dict[str, str] | None = None,
):
timestamp = time.time()
@@ -3562,7 +3588,7 @@ def get_sai_model_spec_dataclass(
sd3: str = None,
flux: str = None,
lumina: str = None,
optional_metadata: dict[str, str] | None = None
optional_metadata: dict[str, str] | None = None,
) -> sai_model_spec.ModelSpecMetadata:
"""
Get ModelSpec metadata as a dataclass - preferred for new code.
@@ -5560,6 +5586,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
def patch_accelerator_for_fp16_training(accelerator):
from accelerate import DistributedType
if accelerator.distributed_type == DistributedType.DEEPSPEED:
return
@@ -6054,7 +6081,6 @@ def get_noise_noisy_latents_and_timesteps(
b_size = latents.shape[0]
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
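
Note: get_timesteps presumably draws one training timestep per latent uniformly from [min_timestep, max_timestep); a hedged sketch (the repository's helper may differ in detail):

import torch

def sample_timesteps(min_t: int, max_t: int, b_size: int, device) -> torch.Tensor:
    # one integer timestep per latent in the batch, in [min_t, max_t)
    return torch.randint(min_t, max_t, (b_size,), device=device, dtype=torch.long)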
@@ -6279,7 +6305,6 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["renorm_cfg"] = float(m.group(1))
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(ex)
@@ -6328,7 +6353,7 @@ def sample_images_common(
vae,
tokenizer,
text_encoder,
unet,
unet_wrapped,
prompt_replacement=None,
controlnet=None,
):
@@ -6363,7 +6388,7 @@ def sample_images_common(
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
unet = accelerator.unwrap_model(unet_wrapped)
if isinstance(text_encoder, (list, tuple)):
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
else:
@@ -6509,7 +6534,7 @@ def sample_image_inference(
logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
with accelerator.autocast():
with accelerator.autocast(), torch.no_grad():
latents = pipeline(
prompt=prompt,
height=height,
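
Note: adding torch.no_grad() alongside autocast keeps sample-image inference from building an autograd graph, which otherwise wastes memory during training-time sampling. A minimal standalone sketch, not the repository's pipeline call:

import torch

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

with torch.autocast("cpu", dtype=torch.bfloat16), torch.no_grad():
    y = model(x)

assert not y.requires_grad  # inference only; no activations retained for backward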
@@ -6647,4 +6672,3 @@ class LossRecorder:
if losses == 0:
return 0
return self.loss_total / losses

View File

@@ -0,0 +1,4 @@
# dummy module for pytorch_lightning
class ModelCheckpoint:
pass
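
Note: with the pytorch-lightning pin commented out in requirements.txt below, this stub presumably exists so that legacy Lightning-era imports and pickled .ckpt files (loaded with weights_only=False, as above) can still resolve a ModelCheckpoint name without installing the real package. A hedged sketch of what the stub enables; the exact module path it must shadow is an assumption:

from pytorch_lightning import ModelCheckpoint  # resolves to the dummy above

callback = ModelCheckpoint()  # inert placeholder with no Lightning behavior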

View File

@@ -1,28 +1,29 @@
accelerate==0.33.0
transformers==4.44.0
diffusers[torch]==0.25.0
ftfy==6.1.1
accelerate==1.6.0
transformers==4.54.1
diffusers[torch]==0.32.1
ftfy==6.3.1
# albumentations==1.3.0
opencv-python==4.8.1.78
opencv-python==4.10.0.84
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.44.0
lion-pytorch==0.0.6
# pytorch-lightning==1.9.0
bitsandbytes==0.45.4
lion-pytorch==0.2.3
schedulefree==1.4
pytorch-optimizer==3.7.0
prodigy-plus-schedule-free==1.9.0
prodigy-plus-schedule-free==1.9.2
prodigyopt==1.1.2
tensorboard
safetensors==0.4.4
safetensors==0.4.5
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
# altair==4.2.2
# easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.24.5
voluptuous==0.15.2
huggingface-hub==0.34.3
# for Image utils
imagesize==1.4.1
numpy<=2.0
numpy
# <=2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
@@ -41,8 +42,8 @@ numpy<=2.0
# open clip for SDXL
# open-clip-torch==2.20.0
# For logging
rich==13.7.0
rich==14.1.0
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece==0.2.0
sentencepiece==0.2.1
# for kohya_ss library
-e .
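
Note: after pin bumps this large (accelerate 0.33.0 to 1.6.0, transformers 4.44.0 to 4.54.1, diffusers 0.25.0 to 0.32.1), it is worth confirming the environment actually resolved the new versions before a long training run. A small sketch, not part of the commit:

from importlib.metadata import version

for pkg in ("accelerate", "transformers", "diffusers", "bitsandbytes", "safetensors"):
    print(f"{pkg}=={version(pkg)}")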

View File

@@ -23,7 +23,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:

View File

@@ -414,13 +414,12 @@ class NetworkTrainer:
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
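
Note: the surrounding block enables gradients for the text-encoder pass only when the text encoder is actually being trained. A minimal sketch of the torch.set_grad_enabled pattern:

import torch

is_train, train_text_encoder = True, False
encoder = torch.nn.Linear(8, 8)
tokens = torch.randn(1, 8)

with torch.set_grad_enabled(is_train and train_text_encoder):
    conds = encoder(tokens)

assert conds.requires_grad == (is_train and train_text_encoder)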
@@ -1340,7 +1339,7 @@ class NetworkTrainer:
)
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
validation_total_steps = validation_steps * len(validation_timesteps)
original_args_min_timestep = args.min_timestep
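
Note: reading num_train_timesteps from noise_scheduler.config (rather than the deprecated top-level attribute) removes another diffusers warning without changing the result. With the default 1000-step schedule, the linspace yields exactly the four interior timesteps named in the comment; a worked example:

import numpy as np

min_timestep, max_timestep = 0, 1000  # noise_scheduler.config.num_train_timesteps
NUM_VALIDATION_TIMESTEPS = 4

# six evenly spaced points with the endpoints dropped -> [200, 400, 600, 800]
ts = np.linspace(min_timestep, max_timestep, NUM_VALIDATION_TIMESTEPS + 2, dtype=int)[1:-1]
print(ts.tolist())  # [200, 400, 600, 800]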