mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
380 lines
14 KiB
Python
380 lines
14 KiB
Python
import os
|
|
import glob
|
|
from typing import Any, List, Optional, Tuple, Union
|
|
import torch
|
|
import numpy as np
|
|
from transformers import AutoTokenizer
|
|
|
|
from library import train_util
|
|
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
|
|
|
from library.utils import setup_logging
|
|
|
|
# Call the project logging helper before this module's logger is created.
# NOTE(review): presumably this installs handlers/formatting; see library.utils.
setup_logging()
import logging

# Module-level logger used by the caching strategies for warnings and errors.
logger = logging.getLogger(__name__)

# Identifier handed to AutoTokenizer (via TokenizeStrategy._load_tokenizer)
# when the tokenize strategy is constructed.
# NOTE(review): presumably a Hugging Face Hub repo id -- confirm.
GLM_TOKENIZER_ID = "THUDM/CogView4-6B"
|
|
|
|
|
|
class CogView4TokenizeStrategy(TokenizeStrategy):
    """Tokenize captions with the tokenizer loaded for ``GLM_TOKENIZER_ID``.

    Every caption is padded or truncated to ``max_length`` tokens, and the
    tokenizer's EOS token is reused as the padding token.
    """

    def __init__(self, max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
        """Load the tokenizer and configure padding.

        Args:
            max_length: Length every tokenized caption is padded/truncated to.
            tokenizer_cache_dir: Optional directory forwarded to
                ``_load_tokenizer`` as ``tokenizer_cache_dir``.
        """
        self.max_length = max_length
        self.tokenizer = self._load_tokenizer(
            AutoTokenizer,
            GLM_TOKENIZER_ID,
            tokenizer_cache_dir=tokenizer_cache_dir,
        )
        # Padding reuses the end-of-sequence token.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """Tokenize one caption or a list of captions.

        Args:
            text: A single caption or a list of captions; a single string is
                wrapped into a one-element list first.

        Returns:
            ``[input_ids, attention_mask]`` as produced by the tokenizer with
            ``return_tensors="pt"``.
        """
        prompts = [text] if isinstance(text, str) else text

        encoded = self.tokenizer(
            prompts,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        return [encoded[key] for key in ("input_ids", "attention_mask")]
|
|
|
|
|
|
class CogView4TextEncodingStrategy(TextEncodingStrategy):
    """Run the GLM text encoder on tokenized captions."""

    def __init__(self, apply_attention_mask: bool = True) -> None:
        """Store the default masking behaviour.

        Args:
            apply_attention_mask: Whether the attention mask is passed to the
                encoder when ``encode_tokens`` is called without an override.
        """
        self.apply_attention_mask = apply_attention_mask

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: List[torch.Tensor],
        apply_attention_mask: Optional[bool] = None,
    ) -> List[torch.Tensor]:
        """Encode token ids with the first model in ``models``.

        Args:
            tokenize_strategy: Not used by this implementation.
            models: List whose first element is the text encoder; it must
                expose ``.device`` and be callable with ``input_ids``,
                ``attention_mask``, ``output_hidden_states`` and ``return_dict``.
            tokens: ``[input_ids]`` or ``[input_ids, attention_mask]``.
            apply_attention_mask: Per-call override; ``None`` falls back to the
                value given at construction time.

        Returns:
            A four-element list ``[hidden_states, hidden_states, txt_ids,
            attention_mask]``. The first two entries are the same tensor
            (occupying the slots the original comments describe as replacing
            ``l_pooled`` and ``t5_out``), ``txt_ids`` is an all-zero
            placeholder whose last dimension is 3, and ``attention_mask`` is
            the device-moved mask (or ``None`` if none was supplied) regardless
            of whether it was applied.
        """
        use_mask = self.apply_attention_mask if apply_attention_mask is None else apply_attention_mask

        # Only a single text encoder is supported.
        text_encoder = models[0]
        ids = tokens[0]
        mask = tokens[1] if len(tokens) > 1 else None

        # Inputs must live on the same device as the encoder.
        device = text_encoder.device
        ids = ids.to(device)
        if mask is not None:
            mask = mask.to(device)

        encoder_mask = mask if use_mask else None
        with torch.no_grad():
            outputs = text_encoder(
                input_ids=ids,
                attention_mask=encoder_mask,
                output_hidden_states=True,
                return_dict=True,
            )

        # Final entry of the hidden-state tuple.
        # NOTE(review): other CogView4 pipelines may use the penultimate layer
        # instead -- confirm which layer the trained model expects.
        hidden_states = outputs.hidden_states[-1]

        batch_size = hidden_states.shape[0]
        seq_len = hidden_states.shape[1]
        txt_ids = torch.zeros(batch_size, seq_len, 3, device=device)

        return [hidden_states, hidden_states, txt_ids, mask]
|
|
|
|
|
|
class CogView4TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Cache CogView4 text-encoder outputs in memory or in per-image ``.npz`` files.

    Each ``.npz`` stores ``hidden_states`` and ``attention_mask`` for a single
    caption plus the ``apply_attention_mask`` flag that was active when the
    cache was written.
    """

    COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_cogview4_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_attention_mask: bool = True,
    ) -> None:
        """Initialise the strategy.

        Args:
            cache_to_disk: Forwarded to the base class; when true, outputs are
                written to ``.npz`` files.
            batch_size: Forwarded to the base class.
            skip_disk_cache_validity_check: Forwarded to the base class; when
                true, an existing file is accepted without being opened.
            is_partial: Forwarded to the base class.
            apply_attention_mask: Flag recorded in every cache file and
                compared against existing files during validation.
        """
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_attention_mask = apply_attention_mask
        # Becomes True once the one-time fp8 weight check has run.
        self.warn_fp8_weights = False

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Return the cache path: the image path with its extension replaced by the suffix."""
        return os.path.splitext(image_abs_path)[0] + CogView4TextEncoderOutputsCachingStrategy.COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        """Tell whether ``npz_path`` holds a usable cache for the current settings.

        Returns:
            ``False`` when disk caching is disabled, the file is missing, a
            required field is absent, the stored ``apply_attention_mask`` flag
            differs from this strategy's, or the file cannot be read;
            ``True`` otherwise (including when validity checking is skipped).
        """
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        required_fields = ("hidden_states", "attention_mask", "apply_attention_mask")
        try:
            # The context manager closes the underlying file handle, which a
            # bare np.load() would keep open.
            with np.load(npz_path) as npz:
                if any(field not in npz for field in required_fields):
                    return False

                npz_apply_attention_mask = bool(npz["apply_attention_mask"])
                if npz_apply_attention_mask != self.apply_attention_mask:
                    return False
        except Exception as e:
            # Best effort: an unreadable cache is reported and treated as absent.
            logger.error(f"Error loading file: {npz_path}")
            logger.exception(e)
            return False

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """Load cached outputs written by ``cache_batch_outputs``.

        Returns:
            ``[hidden_states, hidden_states, txt_ids, attention_mask]``, the
            same four-slot layout as the in-memory cache. ``txt_ids`` is an
            all-zero float32 placeholder shaped like ``hidden_states`` with the
            last dimension replaced by 3.
        """
        with np.load(npz_path) as data:
            hidden_states = data["hidden_states"]
            attention_mask = data["attention_mask"]

        # Files hold one sample, i.e. hidden_states[i] from the batch. Building
        # the placeholder from every leading dimension keeps it consistent with
        # the in-memory path (shape[0], 3) instead of allocating a
        # (shape[0], shape[1], 3) array that includes the feature dimension.
        txt_ids = np.zeros(hidden_states.shape[:-1] + (3,), dtype=np.float32)

        return [
            hidden_states,  # l_pooled replacement
            hidden_states,  # t5_out replacement
            txt_ids,  # txt_ids placeholder
            attention_mask,  # attention mask
        ]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        """Encode the captions of ``infos`` and cache the result per item.

        Args:
            tokenize_strategy: Used to tokenize ``info.caption`` for every item.
            models: Passed to ``text_encoding_strategy.encode_tokens``; the
                first element's parameters are inspected for fp8 dtypes.
            text_encoding_strategy: Produces the hidden states and mask.
            infos: Items exposing ``caption``. With disk caching enabled and a
                ``text_encoder_outputs_npz`` attribute present, the outputs are
                written to that path; otherwise they are stored on
                ``info.text_encoder_outputs``.
        """
        if not self.warn_fp8_weights:
            model_dtype = next(models[0].parameters()).dtype
            if model_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
                logger.warning(
                    "Model is using fp8 weights for caching. This may affect the quality of the cached outputs."
                    " / モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
                )
            # The dtype check only needs to run once per strategy instance.
            self.warn_fp8_weights = True

        captions = [info.caption for info in infos]

        tokens_and_masks = tokenize_strategy.tokenize(captions)

        with torch.no_grad():
            hidden_states, _, _, attention_mask = text_encoding_strategy.encode_tokens(
                tokenize_strategy, models, tokens_and_masks
            )

        # NumPy has no bfloat16 dtype, so widen before converting.
        if hidden_states.dtype == torch.bfloat16:
            hidden_states = hidden_states.float()

        hidden_states = hidden_states.cpu().numpy()
        attention_mask = attention_mask.cpu().numpy() if attention_mask is not None else None

        for i, info in enumerate(infos):
            hidden_states_i = hidden_states[i]
            attention_mask_i = attention_mask[i] if attention_mask is not None else None

            if self.cache_to_disk and hasattr(info, "text_encoder_outputs_npz"):
                # NOTE(review): a None mask would be stored as a pickled object
                # array, which np.load() refuses by default -- confirm the
                # encoder always returns a mask.
                np.savez(
                    info.text_encoder_outputs_npz,
                    hidden_states=hidden_states_i,
                    attention_mask=attention_mask_i,
                    apply_attention_mask=self.apply_attention_mask,
                )
            else:
                txt_ids_i = np.zeros((hidden_states_i.shape[0], 3), dtype=np.float32)
                info.text_encoder_outputs = (hidden_states_i, hidden_states_i, txt_ids_i, attention_mask_i)
|
|
|
|
|
|
class CogView4LatentsCachingStrategy(LatentsCachingStrategy):
    """Latents caching for CogView4, delegating to the base-class default helpers."""

    COGVIEW4_LATENTS_NPZ_SUFFIX = "_cogview4.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        """Forward all three settings unchanged to the base class."""
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        """File-name suffix used for CogView4 latent caches."""
        return CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Build the cache file path for an image at a given size.

        Args:
            absolute_path: Path of the source image; its extension is dropped.
            image_size: Two integers written into the name as ``AAAAxBBBB``
                (zero-padded to four digits, in the order given).
                NOTE(review): whether this is (width, height) or
                (height, width) is not established here -- confirm with callers.

        Returns:
            ``<path without extension>_<size[0]>x<size[1]>_cogview4.npz``.
        """
        stem = os.path.splitext(absolute_path)[0]
        size_tag = f"_{image_size[0]:04d}x{image_size[1]:04d}"
        return stem + size_tag + CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX

    def is_disk_cached_latents_expected(
        self,
        bucket_reso: Tuple[int, int],
        npz_path: str,
        flip_aug: bool,
        alpha_mask: bool,
    ) -> bool:
        """Check whether a valid latent cache exists on disk.

        Args:
            bucket_reso: Target bucket resolution (axis order not established
                here -- confirm with the base class).
            npz_path: Path of the cache file to inspect.
            flip_aug: Whether flip augmentation is in use.
            alpha_mask: Whether an alpha mask is in use.

        Returns:
            Whatever the base-class default check returns for these arguments.
        """
        # 8 is forwarded as the helper's first positional argument.
        # NOTE(review): likely the latent stride (VAE downscale factor) rather
        # than a frame count -- confirm against LatentsCachingStrategy.
        return self._default_is_disk_cached_latents_expected(
            8,
            bucket_reso,
            npz_path,
            flip_aug,
            alpha_mask,
            multi_resolution=True,
        )

    def load_latents_from_disk(
        self,
        npz_path: str,
        bucket_reso: Tuple[int, int],
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """Load cached latents via the base-class default loader.

        Args:
            npz_path: Path of the cache file.
            bucket_reso: Target bucket resolution (axis order not established
                here -- confirm with the base class).

        Returns:
            The five-element tuple produced by the base-class loader; the
            annotation describes it as latents, two optional integer lists and
            two optional arrays.
        """
        # Same leading constant as the validity check above in this class.
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)

    def cache_batch_latents(
        self,
        vae: Any,
        image_infos: List[Any],
        flip_aug: bool,
        alpha_mask: bool,
        random_crop: bool,
    ) -> None:
        """Encode a batch of images with ``vae`` and cache the latents.

        Args:
            vae: Encoder exposing ``encode``, ``device`` and ``dtype``.
            image_infos: Items describing the images to cache.
            flip_aug: Whether flip augmentation is in use.
            alpha_mask: Whether an alpha mask is in use.
            random_crop: Whether random cropping is in use.
        """

        def _encode_on_cpu(pixels: torch.Tensor) -> torch.Tensor:
            # Gradients are never needed for cached latents; results go to CPU.
            with torch.no_grad():
                encoded = vae.encode(pixels)
                return encoded.to("cpu")

        self._default_cache_batch_latents(
            _encode_on_cpu,
            vae.device,
            vae.dtype,
            image_infos,
            flip_aug,
            alpha_mask,
            random_crop,
            multi_resolution=True,
        )

        # Device memory is cleaned up unless the high-VRAM flag is set.
        if train_util.HIGH_VRAM:
            return
        train_util.clean_memory_on_device(vae.device)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: tokenization, text encoding and on-disk caching.
    # It loads real checkpoints, so it needs network access (or a warm cache)
    # and enough memory to hold the model.
    tokenizer = CogView4TokenizeStrategy(512)
    text = "hello world"

    # Single caption.
    input_ids, attention_mask = tokenizer.tokenize(text)
    print("Input IDs:", input_ids)
    print("Attention Mask:", attention_mask)

    # Batch of captions.
    texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
    batch_input_ids, batch_attention_mask = tokenizer.tokenize(texts)
    print("\nBatch Input IDs:", batch_input_ids.shape)
    print("Batch Attention Mask:", batch_attention_mask.shape)

    # Long caption.
    long_text = ",".join(["hello world! this is long text"] * 10)
    long_input_ids, long_attention_mask = tokenizer.tokenize(long_text)
    print("\nLong text input IDs shape:", long_input_ids.shape)
    print("Long text attention mask shape:", long_attention_mask.shape)

    # Text encoding strategy.
    print("\nTesting text encoding strategy...")
    from transformers import AutoModel

    # Load a GLM checkpoint for testing.
    # NOTE: the checkpoint name indicates a 10B-parameter model, and
    # trust_remote_code=True executes code shipped with the checkpoint.
    model = AutoModel.from_pretrained("THUDM/glm-10b-chinese", trust_remote_code=True)
    model.eval()

    encoding_strategy = CogView4TextEncodingStrategy()
    tokens = tokenizer.tokenize(texts)
    encoded = encoding_strategy.encode_tokens(tokenizer, [model], tokens)

    print(f"Number of outputs: {len(encoded)}")
    print(f"Hidden states shape: {encoded[0].shape}")
    print(f"Attention mask shape: {encoded[3].shape if encoded[3] is not None else 'None'}")

    # Caching strategy.
    print("\nTesting caching strategy...")
    import tempfile

    class DummyInfo:
        """Minimal stand-in exposing the attributes cache_batch_outputs reads."""

        def __init__(self, caption, npz_path):
            self.caption = caption
            self.text_encoder_outputs_npz = npz_path

    # tempfile.mktemp() is deprecated and race-prone; a private temporary
    # directory gives unique paths and removes the files automatically.
    with tempfile.TemporaryDirectory() as cache_dir:
        infos = [
            DummyInfo(caption, os.path.join(cache_dir, f"caption_{index}.npz"))
            for index, caption in enumerate(texts)
        ]

        caching_strategy = CogView4TextEncoderOutputsCachingStrategy(
            cache_to_disk=True,
            batch_size=2,
            skip_disk_cache_validity_check=False,
        )

        # Cache the outputs.
        caching_strategy.cache_batch_outputs(tokenizer, [model], encoding_strategy, infos)

        # Report whether each cache file was created.
        for info in infos:
            exists = os.path.exists(info.text_encoder_outputs_npz)
            print(f"Cache file {info.text_encoder_outputs_npz} exists: {exists}")
|