mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
feat: update libraries, remove warnings
This commit is contained in:
@@ -6,6 +6,7 @@ import os
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
import diffusers
|
||||
@@ -14,8 +15,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # ,
|
||||
from safetensors.torch import load_file, save_file
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||
@@ -974,7 +977,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
||||
checkpoint = None
|
||||
state_dict = load_file(ckpt_path) # , device) # may causes error
|
||||
else:
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
|
||||
if "state_dict" in checkpoint:
|
||||
state_dict = checkpoint["state_dict"]
|
||||
else:
|
||||
|
||||
@@ -114,8 +114,10 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
|
||||
@@ -530,7 +532,9 @@ class DownBlock2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -626,15 +630,9 @@ class CrossAttention(nn.Module):
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
attention_mask,
|
||||
) = translate_attention_names_from_diffusers(
|
||||
hidden_states=hidden_states, context=context, mask=mask, **kwargs
|
||||
)
|
||||
) = translate_attention_names_from_diffusers(hidden_states=hidden_states, context=context, mask=mask, **kwargs)
|
||||
return self.processor(
|
||||
attn=self,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=context,
|
||||
attention_mask=mask,
|
||||
**kwargs
|
||||
attn=self, hidden_states=hidden_states, encoder_hidden_states=context, attention_mask=mask, **kwargs
|
||||
)
|
||||
if self.use_memory_efficient_attention_xformers:
|
||||
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
||||
@@ -748,13 +746,14 @@ class CrossAttention(nn.Module):
|
||||
out = self.to_out[0](out)
|
||||
return out
|
||||
|
||||
|
||||
def translate_attention_names_from_diffusers(
|
||||
hidden_states: torch.FloatTensor,
|
||||
context: Optional[torch.FloatTensor] = None,
|
||||
mask: Optional[torch.FloatTensor] = None,
|
||||
# HF naming
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
# translate from hugging face diffusers
|
||||
context = context if context is not None else encoder_hidden_states
|
||||
@@ -764,6 +763,7 @@ def translate_attention_names_from_diffusers(
|
||||
|
||||
return hidden_states, context, mask
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
r"""
|
||||
@@ -1015,9 +1015,11 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -1098,10 +1100,12 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
|
||||
if attn is not None:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
|
||||
)[0]
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
||||
@@ -1201,7 +1205,9 @@ class UpBlock2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -1296,9 +1302,11 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, use_reentrant=False
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -683,7 +683,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
resize_interpolation: Optional[str] = None
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -719,7 +719,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
if resize_interpolation is not None:
|
||||
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
|
||||
assert validate_interpolation_fn(
|
||||
resize_interpolation
|
||||
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
self.image_data: Dict[str, ImageInfo] = {}
|
||||
@@ -1613,7 +1615,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
if self.enable_bucket:
|
||||
img, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
|
||||
subset.random_crop,
|
||||
img,
|
||||
image_info.bucket_reso,
|
||||
image_info.resized_size,
|
||||
resize_interpolation=image_info.resize_interpolation,
|
||||
)
|
||||
else:
|
||||
if face_cx > 0: # 顔位置情報あり
|
||||
@@ -2101,7 +2107,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
info.resize_interpolation = (
|
||||
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
@@ -2385,7 +2393,7 @@ class ControlNetDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
@@ -2448,7 +2456,7 @@ class ControlNetDataset(BaseDataset):
|
||||
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_seed = validation_seed
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
# assert all conditioning data exists
|
||||
@@ -2538,7 +2546,14 @@ class ControlNetDataset(BaseDataset):
|
||||
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||
|
||||
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
cond_img = resize_image(
|
||||
cond_img,
|
||||
original_size_hw[1],
|
||||
original_size_hw[0],
|
||||
target_size_hw[1],
|
||||
target_size_hw[0],
|
||||
self.resize_interpolation,
|
||||
)
|
||||
|
||||
# TODO support random crop
|
||||
# 現在サポートしているcropはrandomではなく中央のみ
|
||||
@@ -2552,7 +2567,14 @@ class ControlNetDataset(BaseDataset):
|
||||
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
# resize to target
|
||||
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
|
||||
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
cond_img = resize_image(
|
||||
cond_img,
|
||||
cond_img.shape[0],
|
||||
cond_img.shape[1],
|
||||
target_size_hw[1],
|
||||
target_size_hw[0],
|
||||
self.resize_interpolation,
|
||||
)
|
||||
|
||||
if flipped:
|
||||
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||
@@ -3000,7 +3022,9 @@ def load_images_and_masks_for_caching(
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
|
||||
)
|
||||
|
||||
original_sizes.append(original_size)
|
||||
crop_ltrbs.append(crop_ltrb)
|
||||
@@ -3041,7 +3065,9 @@ def cache_batch_latents(
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation
|
||||
)
|
||||
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
@@ -3482,9 +3508,9 @@ def get_sai_model_spec(
|
||||
textual_inversion: bool,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
|
||||
sd3: str = None,
|
||||
flux: str = None, # "dev", "schnell" or "chroma"
|
||||
flux: str = None, # "dev", "schnell" or "chroma"
|
||||
lumina: str = None,
|
||||
optional_metadata: dict[str, str] | None = None
|
||||
optional_metadata: dict[str, str] | None = None,
|
||||
):
|
||||
timestamp = time.time()
|
||||
|
||||
@@ -3513,7 +3539,7 @@ def get_sai_model_spec(
|
||||
|
||||
# Extract metadata_* fields from args and merge with optional_metadata
|
||||
extracted_metadata = {}
|
||||
|
||||
|
||||
# Extract all metadata_* attributes from args
|
||||
for attr_name in dir(args):
|
||||
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
|
||||
@@ -3523,7 +3549,7 @@ def get_sai_model_spec(
|
||||
field_name = attr_name[9:] # len("metadata_") = 9
|
||||
if field_name not in ["title", "author", "description", "license", "tags"]:
|
||||
extracted_metadata[field_name] = value
|
||||
|
||||
|
||||
# Merge extracted metadata with provided optional_metadata
|
||||
all_optional_metadata = {**extracted_metadata}
|
||||
if optional_metadata:
|
||||
@@ -3546,7 +3572,7 @@ def get_sai_model_spec(
|
||||
tags=args.metadata_tags,
|
||||
timesteps=timesteps,
|
||||
clip_skip=args.clip_skip, # None or int
|
||||
model_config=model_config,
|
||||
model_config=model_config,
|
||||
optional_metadata=all_optional_metadata if all_optional_metadata else None,
|
||||
)
|
||||
return metadata
|
||||
@@ -3562,7 +3588,7 @@ def get_sai_model_spec_dataclass(
|
||||
sd3: str = None,
|
||||
flux: str = None,
|
||||
lumina: str = None,
|
||||
optional_metadata: dict[str, str] | None = None
|
||||
optional_metadata: dict[str, str] | None = None,
|
||||
) -> sai_model_spec.ModelSpecMetadata:
|
||||
"""
|
||||
Get ModelSpec metadata as a dataclass - preferred for new code.
|
||||
@@ -5558,11 +5584,12 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
|
||||
def patch_accelerator_for_fp16_training(accelerator):
|
||||
|
||||
|
||||
from accelerate import DistributedType
|
||||
|
||||
if accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
return
|
||||
|
||||
|
||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||
|
||||
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||
@@ -6054,7 +6081,6 @@ def get_noise_noisy_latents_and_timesteps(
|
||||
b_size = latents.shape[0]
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
|
||||
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
@@ -6279,7 +6305,6 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
prompt_dict["renorm_cfg"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
|
||||
except ValueError as ex:
|
||||
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
||||
logger.error(ex)
|
||||
@@ -6328,7 +6353,7 @@ def sample_images_common(
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
unet_wrapped,
|
||||
prompt_replacement=None,
|
||||
controlnet=None,
|
||||
):
|
||||
@@ -6363,7 +6388,7 @@ def sample_images_common(
|
||||
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = accelerator.unwrap_model(unet_wrapped)
|
||||
if isinstance(text_encoder, (list, tuple)):
|
||||
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
|
||||
else:
|
||||
@@ -6509,7 +6534,7 @@ def sample_image_inference(
|
||||
logger.info(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
logger.info(f"seed: {seed}")
|
||||
with accelerator.autocast():
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
latents = pipeline(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
@@ -6647,4 +6672,3 @@ class LossRecorder:
|
||||
if losses == 0:
|
||||
return 0
|
||||
return self.loss_total / losses
|
||||
|
||||
|
||||
Reference in New Issue
Block a user