Kohya-ss-sd-scripts/library/strategy_cogview4.py

import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import AutoTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
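
# Strategies for CogView4 training: GLM-based tokenization, text encoder output
# caching, and VAE latents caching, implementing the TokenizeStrategy /
# TextEncodingStrategy / *CachingStrategy interfaces from library.strategy_base.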
GLM_TOKENIZER_ID = "THUDM/CogView4-6B"
class CogView4TokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
self.max_length = max_length
self.tokenizer = self._load_tokenizer(AutoTokenizer, GLM_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        # The GLM tokenizer may not define a pad token; fall back to the EOS token for padding
        self.tokenizer.pad_token = self.tokenizer.eos_token
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
# Tokenize with GLM tokenizer
tokens = self.tokenizer(
text,
max_length=self.max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
)
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]
return [input_ids, attention_mask]
class CogView4TextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_attention_mask: bool = True) -> None:
"""
Args:
apply_attention_mask: Whether to apply attention mask during encoding.
"""
self.apply_attention_mask = apply_attention_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_attention_mask: Optional[bool] = None,
) -> List[torch.Tensor]:
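        """Encode GLM tokens and return a 4-element list mirroring the CLIP/T5
        output layout the caching code expects:
        [hidden_states, hidden_states, txt_ids placeholder, attention_mask].
        """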
# supports single model inference
if apply_attention_mask is None:
apply_attention_mask = self.apply_attention_mask
# Get GLM model (should be the only model in the list)
glm_model = models[0]
input_ids = tokens[0]
attention_mask = tokens[1] if len(tokens) > 1 else None
# Move tensors to the correct device
device = glm_model.device
input_ids = input_ids.to(device)
if attention_mask is not None:
attention_mask = attention_mask.to(device)
# Get GLM model outputs
with torch.no_grad():
outputs = glm_model(
input_ids=input_ids,
attention_mask=attention_mask if apply_attention_mask else None,
output_hidden_states=True,
return_dict=True
)
# Get the last hidden state
hidden_states = outputs.hidden_states[-1] # [batch_size, seq_len, hidden_size]
# For compatibility with existing code, we'll return a list similar to the original
# but with GLM's hidden states instead of CLIP/T5 outputs
return [
hidden_states, # Replaces l_pooled
hidden_states, # Replaces t5_out (same tensor for now, can be modified if needed)
torch.zeros(hidden_states.shape[0], hidden_states.shape[1], 3, device=device), # txt_ids placeholder
attention_mask # attention mask
]
class CogView4TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_cogview4_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_attention_mask: bool = True,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_attention_mask = apply_attention_mask
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
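        """Return the text-encoder cache path, e.g. "/data/img.png" -> "/data/img_cogview4_te.npz"."""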
return os.path.splitext(image_abs_path)[0] + CogView4TextEncoderOutputsCachingStrategy.COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
required_fields = ["hidden_states", "attention_mask", "apply_attention_mask"]
for field in required_fields:
if field not in npz:
return False
npz_apply_attention_mask = bool(npz["apply_attention_mask"])
if npz_apply_attention_mask != self.apply_attention_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
logger.exception(e)
return False
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
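        """Load cached per-image text-encoder outputs and rebuild the 4-element list produced by encode_tokens."""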
data = np.load(npz_path)
hidden_states = data["hidden_states"]
attention_mask = data["attention_mask"]
        return [
            hidden_states,  # l_pooled replacement
            hidden_states,  # t5_out replacement
            # txt_ids placeholder; [seq_len, 3] to match the per-image shape written at cache time
            np.zeros((hidden_states.shape[0], 3), dtype=np.float32),
            attention_mask,  # attention mask
        ]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
if not self.warn_fp8_weights:
model_dtype = next(models[0].parameters()).dtype
if model_dtype == torch.float8_e4m3fn or model_dtype == torch.float8_e5m2:
logger.warning(
"Model is using fp8 weights for caching. This may affect the quality of the cached outputs."
" / モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
)
self.warn_fp8_weights = True
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_states, _, _, attention_mask = text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks
)
if hidden_states.dtype == torch.bfloat16:
hidden_states = hidden_states.float()
hidden_states = hidden_states.cpu().numpy()
attention_mask = attention_mask.cpu().numpy() if attention_mask is not None else None
for i, info in enumerate(infos):
hidden_states_i = hidden_states[i]
attention_mask_i = attention_mask[i] if attention_mask is not None else None
if self.cache_to_disk and hasattr(info, 'text_encoder_outputs_npz'):
np.savez(
info.text_encoder_outputs_npz,
hidden_states=hidden_states_i,
attention_mask=attention_mask_i,
apply_attention_mask=self.apply_attention_mask,
)
else:
info.text_encoder_outputs = (hidden_states_i, hidden_states_i, np.zeros((hidden_states_i.shape[0], 3), dtype=np.float32), attention_mask_i)
class CogView4LatentsCachingStrategy(LatentsCachingStrategy):
COGVIEW4_LATENTS_NPZ_SUFFIX = "_cogview4.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
"""Get the path for cached latents.
Args:
absolute_path: Absolute path to the source image
image_size: Tuple of (height, width) for the target resolution
Returns:
Path to the cached latents file
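        Example:
            "/data/img.png" with image_size (1024, 768) -> "/data/img_1024x0768_cogview4.npz"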
"""
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool
) -> bool:
"""Check if the latents are already cached and valid.
Args:
bucket_reso: Target resolution as (height, width)
npz_path: Path to the cached latents file
flip_aug: Whether flip augmentation was applied
alpha_mask: Whether alpha mask was used
Returns:
bool: True if valid cache exists, False otherwise
"""
        # The first argument (8) is the latent stride, i.e. the VAE downsampling factor
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
def load_latents_from_disk(
self,
npz_path: str,
bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""Load latents from disk.
Args:
npz_path: Path to the cached latents file
bucket_reso: Target resolution as (height, width)
Returns:
Tuple containing:
- latents: The loaded latents or None if loading failed
- original_size: Original image size as [height, width]
- crop_top_left: Crop offset as [top, left]
- alpha_mask: Alpha mask if available
- alpha_mask_origin: Original alpha mask if available
"""
        # The first argument (8) is the latent stride, i.e. the VAE downsampling factor
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
def cache_batch_latents(
self,
vae: Any,
image_infos: List[Any],
flip_aug: bool,
alpha_mask: bool,
random_crop: bool
) -> None:
"""Cache a batch of latents.
Args:
vae: The VAE model used for encoding
image_infos: List of image information objects
flip_aug: Whether to apply flip augmentation
alpha_mask: Whether to use alpha mask
random_crop: Whether to apply random crop
"""
# Define encoding function that moves output to CPU
def encode_by_vae(img_tensor: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
return vae.encode(img_tensor).to("cpu")
# Get VAE device and dtype
vae_device = vae.device
vae_dtype = vae.dtype
# Cache latents using the default implementation
self._default_cache_batch_latents(
encode_by_vae,
vae_device,
vae_dtype,
image_infos,
flip_aug,
alpha_mask,
random_crop,
multi_resolution=True
)
# Clean up GPU memory if not in high VRAM mode
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
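
# Typical wiring (sketch; assumes the set_strategy() classmethods provided by the
# base classes in library.strategy_base):
#   TokenizeStrategy.set_strategy(CogView4TokenizeStrategy(512))
#   TextEncodingStrategy.set_strategy(CogView4TextEncodingStrategy())
#   LatentsCachingStrategy.set_strategy(CogView4LatentsCachingStrategy(True, 1, False))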
if __name__ == "__main__":
# Test code for CogView4TokenizeStrategy
tokenizer = CogView4TokenizeStrategy(512)
text = "hello world"
# Test single text tokenization
input_ids, attention_mask = tokenizer.tokenize(text)
print("Input IDs:", input_ids)
print("Attention Mask:", attention_mask)
# Test batch tokenization
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
batch_input_ids, batch_attention_mask = tokenizer.tokenize(texts)
print("\nBatch Input IDs:", batch_input_ids.shape)
print("Batch Attention Mask:", batch_attention_mask.shape)
# Test with a long text
long_text = ",".join(["hello world! this is long text"] * 10)
long_input_ids, long_attention_mask = tokenizer.tokenize(long_text)
print("\nLong text input IDs shape:", long_input_ids.shape)
print("Long text attention mask shape:", long_attention_mask.shape)
# Test text encoding strategy
print("\nTesting text encoding strategy...")
from transformers import AutoModel
    # Load a GLM model as a stand-in text encoder for testing (note: this checkpoint is large and requires trust_remote_code)
    model = AutoModel.from_pretrained("THUDM/glm-10b-chinese", trust_remote_code=True)
model.eval()
encoding_strategy = CogView4TextEncodingStrategy()
tokens = tokenizer.tokenize(texts)
encoded = encoding_strategy.encode_tokens(tokenizer, [model], tokens)
print(f"Number of outputs: {len(encoded)}")
print(f"Hidden states shape: {encoded[0].shape}")
print(f"Attention mask shape: {encoded[3].shape if encoded[3] is not None else 'None'}")
# Test caching strategy
print("\nTesting caching strategy...")
import tempfile
import os
class DummyInfo:
def __init__(self, caption):
self.caption = caption
            # tempfile.mktemp is deprecated but acceptable for a throwaway test path
            self.text_encoder_outputs_npz = tempfile.mktemp(suffix=".npz")
# Create test data
infos = [DummyInfo(text) for text in texts]
# Test caching
caching_strategy = CogView4TextEncoderOutputsCachingStrategy(
cache_to_disk=True,
batch_size=2,
skip_disk_cache_validity_check=False
)
# Cache the outputs
caching_strategy.cache_batch_outputs(tokenizer, [model], encoding_strategy, infos)
# Check if files were created
for info in infos:
exists = os.path.exists(info.text_encoder_outputs_npz)
print(f"Cache file {info.text_encoder_outputs_npz} exists: {exists}")
# Clean up
for info in infos:
if os.path.exists(info.text_encoder_outputs_npz):
os.remove(info.text_encoder_outputs_npz)