commit 201e1997a2 (parent e2ed265104)
Author: sdbds
Date:   2025-05-26 10:03:30 +08:00

2 changed files with 996 additions and 0 deletions


@@ -0,0 +1,379 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union

import torch
import numpy as np
from transformers import AutoTokenizer

from library import train_util
from library.strategy_base import (
    LatentsCachingStrategy,
    TextEncodingStrategy,
    TokenizeStrategy,
    TextEncoderOutputsCachingStrategy,
)
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


GLM_TOKENIZER_ID = "THUDM/CogView4-6B"


class CogView4TokenizeStrategy(TokenizeStrategy):
    def __init__(self, max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
        self.max_length = max_length
        self.tokenizer = self._load_tokenizer(AutoTokenizer, GLM_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        # Add special tokens if needed
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        text = [text] if isinstance(text, str) else text

        # Tokenize with GLM tokenizer
        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]
        return [input_ids, attention_mask]
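    # For reference: with the default max_length of 512, tokenize("a cat")
    # returns input_ids and attention_mask of shape [1, 512] (padded to
    # max_length); a list of N captions returns tensors of shape [N, 512].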


class CogView4TextEncodingStrategy(TextEncodingStrategy):
    def __init__(self, apply_attention_mask: bool = True) -> None:
        """
        Args:
            apply_attention_mask: Whether to apply attention mask during encoding.
        """
        self.apply_attention_mask = apply_attention_mask

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: List[torch.Tensor],
        apply_attention_mask: Optional[bool] = None,
    ) -> List[torch.Tensor]:
        # supports single model inference
        if apply_attention_mask is None:
            apply_attention_mask = self.apply_attention_mask

        # Get GLM model (should be the only model in the list)
        glm_model = models[0]
        input_ids = tokens[0]
        attention_mask = tokens[1] if len(tokens) > 1 else None

        # Move tensors to the correct device
        device = glm_model.device
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # Get GLM model outputs
        with torch.no_grad():
            outputs = glm_model(
                input_ids=input_ids,
                attention_mask=attention_mask if apply_attention_mask else None,
                output_hidden_states=True,
                return_dict=True,
            )

        # Get the last hidden state
        hidden_states = outputs.hidden_states[-1]  # [batch_size, seq_len, hidden_size]

        # For compatibility with existing code, we'll return a list similar to the original
        # but with GLM's hidden states instead of CLIP/T5 outputs
        return [
            hidden_states,  # Replaces l_pooled
            hidden_states,  # Replaces t5_out (same tensor for now, can be modified if needed)
            torch.zeros(hidden_states.shape[0], hidden_states.shape[1], 3, device=device),  # txt_ids placeholder
            attention_mask,  # attention mask
        ]
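    # All four entries mirror the (pooled, encoder output, txt_ids, attention mask)
    # layout that callers of encode_tokens expect elsewhere in the library; only the
    # GLM hidden states carry real information, and txt_ids is a zero placeholder.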


class CogView4TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_cogview4_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_attention_mask: bool = True,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_attention_mask = apply_attention_mask
        self.warn_fp8_weights = False

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + CogView4TextEncoderOutputsCachingStrategy.COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            required_fields = ["hidden_states", "attention_mask", "apply_attention_mask"]
            for field in required_fields:
                if field not in npz:
                    return False
            npz_apply_attention_mask = bool(npz["apply_attention_mask"])
            if npz_apply_attention_mask != self.apply_attention_mask:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            logger.exception(e)
            return False

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        data = np.load(npz_path)
        hidden_states = data["hidden_states"]  # [seq_len, hidden_size] for a single image
        attention_mask = data["attention_mask"]
        return [
            hidden_states,  # l_pooled replacement
            hidden_states,  # t5_out replacement
            np.zeros((hidden_states.shape[0], 3), dtype=np.float32),  # txt_ids placeholder, same shape as in cache_batch_outputs
            attention_mask,  # attention mask
        ]
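    # Note: the order of the list above must match what encode_tokens returns, so
    # that disk-cached and freshly-encoded text encoder outputs are interchangeable.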

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        if not self.warn_fp8_weights:
            model_dtype = next(models[0].parameters()).dtype
            if model_dtype == torch.float8_e4m3fn or model_dtype == torch.float8_e5m2:
                logger.warning(
                    "Model is using fp8 weights for caching. This may affect the quality of the cached outputs."
                    " / モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
                )
            self.warn_fp8_weights = True

        captions = [info.caption for info in infos]
        tokens_and_masks = tokenize_strategy.tokenize(captions)

        with torch.no_grad():
            hidden_states, _, _, attention_mask = text_encoding_strategy.encode_tokens(
                tokenize_strategy, models, tokens_and_masks
            )

        if hidden_states.dtype == torch.bfloat16:
            hidden_states = hidden_states.float()

        hidden_states = hidden_states.cpu().numpy()
        attention_mask = attention_mask.cpu().numpy() if attention_mask is not None else None

        for i, info in enumerate(infos):
            hidden_states_i = hidden_states[i]
            attention_mask_i = attention_mask[i] if attention_mask is not None else None

            if self.cache_to_disk and hasattr(info, "text_encoder_outputs_npz"):
                np.savez(
                    info.text_encoder_outputs_npz,
                    hidden_states=hidden_states_i,
                    attention_mask=attention_mask_i,
                    apply_attention_mask=self.apply_attention_mask,
                )
            else:
                info.text_encoder_outputs = (
                    hidden_states_i,
                    hidden_states_i,
                    np.zeros((hidden_states_i.shape[0], 3), dtype=np.float32),
                    attention_mask_i,
                )


class CogView4LatentsCachingStrategy(LatentsCachingStrategy):
    COGVIEW4_LATENTS_NPZ_SUFFIX = "_cogview4.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        return CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Get the path for cached latents.

        Args:
            absolute_path: Absolute path to the source image
            image_size: Tuple of (height, width) for the target resolution

        Returns:
            Path to the cached latents file
        """
        return (
            os.path.splitext(absolute_path)[0]
            + f"_{image_size[0]:04d}x{image_size[1]:04d}"
            + CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX
        )
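    # For example, "/data/img.png" with image_size (1024, 1024) maps to
    # "/data/img_1024x1024_cogview4.npz".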

    def is_disk_cached_latents_expected(
        self,
        bucket_reso: Tuple[int, int],
        npz_path: str,
        flip_aug: bool,
        alpha_mask: bool,
    ) -> bool:
        """Check if the latents are already cached and valid.

        Args:
            bucket_reso: Target resolution as (height, width)
            npz_path: Path to the cached latents file
            flip_aug: Whether flip augmentation was applied
            alpha_mask: Whether alpha mask was used

        Returns:
            bool: True if valid cache exists, False otherwise
        """
        # 8 is the VAE spatial downscaling factor (latents stride)
        return self._default_is_disk_cached_latents_expected(
            8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
        )

    def load_latents_from_disk(
        self,
        npz_path: str,
        bucket_reso: Tuple[int, int],
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """Load latents from disk.

        Args:
            npz_path: Path to the cached latents file
            bucket_reso: Target resolution as (height, width)

        Returns:
            Tuple containing:
            - latents: The loaded latents or None if loading failed
            - original_size: Original image size as [height, width]
            - crop_top_left: Crop offset as [top, left]
            - flipped_latents: Flipped latents if flip augmentation was cached
            - alpha_mask: Alpha mask if available
        """
        # 8 is the VAE spatial downscaling factor (latents stride)
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)

    def cache_batch_latents(
        self,
        vae: Any,
        image_infos: List[Any],
        flip_aug: bool,
        alpha_mask: bool,
        random_crop: bool,
    ) -> None:
        """Cache a batch of latents.

        Args:
            vae: The VAE model used for encoding
            image_infos: List of image information objects
            flip_aug: Whether to apply flip augmentation
            alpha_mask: Whether to use alpha mask
            random_crop: Whether to apply random crop
        """

        # Define encoding function that moves output to CPU
        def encode_by_vae(img_tensor: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                return vae.encode(img_tensor).to("cpu")

        # Get VAE device and dtype
        vae_device = vae.device
        vae_dtype = vae.dtype

        # Cache latents using the default implementation
        self._default_cache_batch_latents(
            encode_by_vae,
            vae_device,
            vae_dtype,
            image_infos,
            flip_aug,
            alpha_mask,
            random_crop,
            multi_resolution=True,
        )

        # Clean up GPU memory if not in high VRAM mode
        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)


if __name__ == "__main__":
    # Test code for CogView4TokenizeStrategy
    tokenizer = CogView4TokenizeStrategy(512)
    text = "hello world"

    # Test single text tokenization
    input_ids, attention_mask = tokenizer.tokenize(text)
    print("Input IDs:", input_ids)
    print("Attention Mask:", attention_mask)

    # Test batch tokenization
    texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
    batch_input_ids, batch_attention_mask = tokenizer.tokenize(texts)
    print("\nBatch Input IDs:", batch_input_ids.shape)
    print("Batch Attention Mask:", batch_attention_mask.shape)

    # Test with a long text
    long_text = ",".join(["hello world! this is long text"] * 10)
    long_input_ids, long_attention_mask = tokenizer.tokenize(long_text)
    print("\nLong text input IDs shape:", long_input_ids.shape)
    print("Long text attention mask shape:", long_attention_mask.shape)

    # Test text encoding strategy
    print("\nTesting text encoding strategy...")
    from transformers import AutoModel

    # Load a GLM model for testing (downloaded from Hugging Face on first run)
    model = AutoModel.from_pretrained("THUDM/glm-10b-chinese", trust_remote_code=True)
    model.eval()

    encoding_strategy = CogView4TextEncodingStrategy()
    tokens = tokenizer.tokenize(texts)
    encoded = encoding_strategy.encode_tokens(tokenizer, [model], tokens)
    print(f"Number of outputs: {len(encoded)}")
    print(f"Hidden states shape: {encoded[0].shape}")
    print(f"Attention mask shape: {encoded[3].shape if encoded[3] is not None else 'None'}")

    # Test caching strategy
    print("\nTesting caching strategy...")
    import tempfile

    class DummyInfo:
        def __init__(self, caption):
            self.caption = caption
            self.text_encoder_outputs_npz = tempfile.mktemp(suffix=".npz")

    # Create test data
    infos = [DummyInfo(text) for text in texts]

    # Test caching
    caching_strategy = CogView4TextEncoderOutputsCachingStrategy(
        cache_to_disk=True,
        batch_size=2,
        skip_disk_cache_validity_check=False,
    )

    # Cache the outputs
    caching_strategy.cache_batch_outputs(tokenizer, [model], encoding_strategy, infos)

    # Check if files were created
    for info in infos:
        exists = os.path.exists(info.text_encoder_outputs_npz)
        print(f"Cache file {info.text_encoder_outputs_npz} exists: {exists}")

    # Clean up
    for info in infos:
        if os.path.exists(info.text_encoder_outputs_npz):
            os.remove(info.text_encoder_outputs_npz)