# Kohya-ss-sd-scripts/library/strategy_sd.py
import os
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import CLIPTokenizer

from library import train_util
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from this model; v2 and v2.1 share the same tokenizer spec


class SdTokenizeStrategy(TokenizeStrategy):
    def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
        """
        max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
        """
        logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
        if v2:
            self.tokenizer = self._load_tokenizer(
                CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
            )
        else:
            self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

        if max_length is None:
            self.max_length = self.tokenizer.model_max_length
        else:
            self.max_length = max_length + 2

        self.break_separator = "BREAK"
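
    # Example: max_length=225 stores self.max_length = 227 (225 content tokens plus <BOS> and <EOS>);
    # with the chunked tokenization in strategy_base this corresponds to 3 chunks of 77 tokens.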

    def _split_on_break(self, text: str) -> List[str]:
        """Split text on the BREAK separator (case-sensitive), filtering empty segments."""
        segments = text.split(self.break_separator)
        # Filter out empty or whitespace-only segments
        filtered = [seg.strip() for seg in segments if seg.strip()]
        # Return at least one segment to maintain consistency
        return filtered if filtered else [""]
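
    # Example: "a castle BREAK oil painting" -> ["a castle", "oil painting"];
    # "BREAK" with nothing around it -> [""], so callers always receive at least one segment.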

    def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Tokenize multiple segments and concatenate them."""
        if len(segments) == 1:
            # No BREAK present, use existing logic
            if weighted:
                return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True)
            else:
                tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length)
                return tokens, None

        # Multiple segments - tokenize each separately
        all_tokens = []
        all_weights = [] if weighted else None
        for segment in segments:
            if weighted:
                seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True)
                all_tokens.append(seg_tokens)
                all_weights.append(seg_weights)
            else:
                seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length)
                all_tokens.append(seg_tokens)

        # Concatenate along the sequence dimension (dim=1 for 2D [n_chunks, seq_len] tensors, dim=0 for 1D)
        combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0)
        combined_weights = None
        if weighted:
            combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0)
        return combined_tokens, combined_weights
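
    # Example: two BREAK segments that each tokenize to a 1D (77,) tensor concatenate to (154,);
    # 2D chunked tensors concatenate along dim=1 instead. The exact per-segment shape
    # depends on strategy_base._get_input_ids.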

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        text = [text] if isinstance(text, str) else text

        tokens_list = []
        for t in text:
            segments = self._split_on_break(t)
            tokens, _ = self._tokenize_segments(segments, weighted=False)
            tokens_list.append(tokens)

        # Pad tokens to same length for stacking
        max_length = max(t.shape[-1] for t in tokens_list)
        padded_tokens = []
        for tokens in tokens_list:
            if tokens.shape[-1] < max_length:
                # Pad with pad_token_id
                pad_size = max_length - tokens.shape[-1]
                if tokens.dim() == 2:
                    padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype)
                    tokens = torch.cat([tokens, padding], dim=1)
                else:
                    padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype)
                    tokens = torch.cat([tokens, padding], dim=0)
            padded_tokens.append(tokens)
        return [torch.stack(padded_tokens, dim=0)]
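
    # Minimal usage sketch (downloads the CLIP tokenizer on first use):
    #   strategy = SdTokenizeStrategy(v2=False, max_length=75)
    #   tokens = strategy.tokenize("a castle on a hill BREAK oil painting")[0]
    #   # tokens is a stacked, padded batch; its length depends on segment count and chunking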

    def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        text = [text] if isinstance(text, str) else text
        tokens_list = []
        weights_list = []
        for t in text:
            segments = self._split_on_break(t)
            tokens, weights = self._tokenize_segments(segments, weighted=True)
            tokens_list.append(tokens)
            weights_list.append(weights)

        # Pad to a common length before stacking (mirrors tokenize): tokens with pad_token_id,
        # weights with a neutral 1.0, so prompts with different BREAK segment counts can share a batch
        max_length = max(t.shape[-1] for t in tokens_list)
        for i, (tokens, weights) in enumerate(zip(tokens_list, weights_list)):
            pad_size = max_length - tokens.shape[-1]
            if pad_size > 0:
                pad_shape = (*tokens.shape[:-1], pad_size)
                tokens_list[i] = torch.cat([tokens, torch.full(pad_shape, self.tokenizer.pad_token_id, dtype=tokens.dtype)], dim=-1)
                weights_list[i] = torch.cat([weights, torch.ones(pad_shape, dtype=weights.dtype)], dim=-1)
        return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
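
    # Note: padding weights with 1.0 leaves the padded positions' embeddings unscaled,
    # matching how unweighted <PAD> tokens are treated by encode_tokens_with_weights.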


class SdTextEncodingStrategy(TextEncodingStrategy):
    def __init__(self, clip_skip: Optional[int] = None) -> None:
        self.clip_skip = clip_skip

    def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor:
        """Encode tokens with optional CLIP skip."""
        if self.clip_skip is None:
            return text_encoder(tokens)[0]
        enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
        hidden_states = enc_out["hidden_states"][-self.clip_skip]
        return text_encoder.text_model.final_layer_norm(hidden_states)
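
    # Example: clip_skip=2 takes the hidden states from the second-to-last transformer
    # layer and re-applies final_layer_norm, instead of using the encoder's last layer.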

    def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor,
                                max_token_length: int, model_max_length: int,
                                tokenizer: Any) -> torch.Tensor:
        """Reconstruct embeddings from chunked encoding."""
        v1 = tokenizer.pad_token_id == tokenizer.eos_token_id
        n_chunks = max_token_length // model_max_length
        states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]  # <BOS>
        if not v1:
            # v2: restore the concatenated <BOS>...<EOS> <PAD> ... chunks into a single <BOS>...<EOS> <PAD> ... sequence
            for i in range(1, max_token_length, model_max_length):
                chunk = encoder_hidden_states[:, i : i + model_max_length - 2]
                chunk_idx = (i - 1) // model_max_length
                for j in range(len(chunk)):
                    # tokens was reshaped to (batch * n_chunks, model_max_length); pick this batch item's chunk
                    if tokens[j * n_chunks + chunk_idx, 1] == tokenizer.eos_token_id:
                        # empty chunk (<BOS> <EOS> <PAD> ...): overwrite the <EOS> position with the <PAD> value
                        chunk[j, 0] = chunk[j, 1]
                states_list.append(chunk)
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
        else:
            # v1: restore the concatenated <BOS>...<EOS> chunks into a single <BOS>...<EOS> sequence
            for i in range(1, max_token_length, model_max_length):
                states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])
            states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
        return torch.cat(states_list, dim=1)
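
    # Shape example: max_length 227 (225 + <BOS>/<EOS>) is tokenized as 3 chunks of 77,
    # encoded to (b, 231, dim), then collapsed here to (b, 227, dim) by dropping the
    # per-chunk <BOS>/<EOS> duplicates: 1 + 3 * 75 + 1 = 227 positions.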

    def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor,
                                    weights: torch.Tensor) -> torch.Tensor:
        """Apply weights for the single-chunk case (no max_token_length)."""
        return encoder_hidden_states * weights.squeeze(1).unsqueeze(2)

    def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor,
                                   weights: torch.Tensor) -> torch.Tensor:
        """Apply weights for the multi-chunk case (with max_token_length)."""
        for i in range(weights.shape[1]):
            start_idx = i * 75 + 1
            end_idx = i * 75 + 76
            encoder_hidden_states[:, start_idx:end_idx] = (
                encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1)
            )
        return encoder_hidden_states
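
    # Note: the 75-wide windows (i * 75 + 1 .. i * 75 + 76) address chunk i's content
    # positions in the reconstructed sequence, so this hardcodes model_max_length == 77;
    # weights[:, i, 1:-1] drops each chunk's <BOS>/<EOS> weight slots to match.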

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        text_encoder = models[0]
        tokens = tokens[0]
        sd_tokenize_strategy = tokenize_strategy  # type: SdTokenizeStrategy

        b_size = tokens.size()[0]
        max_token_length = tokens.size()[1] * tokens.size()[2]
        model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
        tokens = tokens.reshape((-1, model_max_length))
        tokens = tokens.to(text_encoder.device)

        encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens)
        encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

        if max_token_length != model_max_length:
            encoder_hidden_states = self._reconstruct_embeddings(
                encoder_hidden_states, tokens, max_token_length,
                model_max_length, sd_tokenize_strategy.tokenizer
            )
        return [encoder_hidden_states]
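
    # Minimal sketch (assumes a loaded transformers CLIPTextModel as text_encoder):
    #   tokenize_strategy = SdTokenizeStrategy(v2=False, max_length=225)
    #   encoding_strategy = SdTextEncodingStrategy(clip_skip=2)
    #   tokens = tokenize_strategy.tokenize("a castle on a hill, oil painting")
    #   hidden_states = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]
    #   # (1, 227, 768) for SD v1 with max_length=225; BREAK prompts yield proportionally longer sequences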

    def encode_tokens_with_weights(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens_list: List[torch.Tensor],
        weights_list: List[torch.Tensor],
    ) -> List[torch.Tensor]:
        encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]

        weights = weights_list[0].to(encoder_hidden_states.device)
        if weights.shape[1] == 1:
            encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights)
        else:
            encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights)
        return [encoder_hidden_states]
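
    # The weighted path reuses encode_tokens and then scales embeddings in place:
    # weights of shape (b, 1, 77) take the single-chunk path, (b, n, 77) the multi-chunk
    # path that follows the reconstructed layout from _reconstruct_embeddings.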


class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
    # sd and sdxl share the same strategy. we could make them separate, but the only difference is the suffix,
    # and we keep the old npz for backward compatibility.
    SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
    SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
    SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"

    def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
        self.sd = sd
        self.suffix = (
            SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
        )

    @property
    def cache_suffix(self) -> str:
        return self.suffix

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        # support old .npz
        old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
        if os.path.exists(old_npz_file):
            return old_npz_file
        return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
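
    # Example: "/data/img.png" with image_size (1024, 768) -> "/data/img_1024x0768_sd.npz"
    # (or the legacy "/data/img.npz" if that file already exists on disk).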

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
        vae_device = vae.device
        vae_dtype = vae.dtype

        self._default_cache_batch_latents(
            encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)
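

if __name__ == "__main__":
    # Minimal smoke test of the tokenizer path (illustrative only; run from the repository
    # root, downloads the v1 CLIP tokenizer on first use)
    strategy = SdTokenizeStrategy(v2=False, max_length=75)
    batch = strategy.tokenize(["a castle on a hill", "portrait BREAK oil painting"])
    print(batch[0].shape)  # padded, stacked token ids; length grows with BREAK segments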