Compare commits

...

3 Commits

Author SHA1 Message Date
Qing Long
003cb8ee1c Merge 201e1997a2 into fa53f71ec0 2026-04-05 01:17:07 +00:00
Kohya S.
fa53f71ec0 fix: improve numerical stability by conditionally using float32 in Anima (#2302)
* fix: improve numerical stability by conditionally using float32 in block computations

* doc: update README for improvement stability for fp16 training on Anima in version 0.10.3
2026-04-02 12:36:29 +09:00
sdbds
201e1997a2 init 2025-05-26 10:03:30 +08:00
5 changed files with 1014 additions and 4 deletions


@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。


@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
### Change History
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details.

cogview4_train_network.py (new file, 617 lines)

@@ -0,0 +1,617 @@
import argparse
import copy
import math
import random
import os
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.cogview4.pipeline_cogview4 import CogView4Pipeline
from PIL import Image
from transformers import AutoTokenizer, GlmModel
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
import train_network
from library import (
flux_models,
flux_train_utils,
flux_utils,
sd3_train_utils,
strategy_base,
strategy_cogview4,
train_util,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class CogView4NetworkTrainer(train_network.NetworkTrainer):
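    """NetworkTrainer for CogView4: GLM text encoder, AutoencoderKL VAE, and CogView4Transformer2DModel."""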
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_swapping_blocks: bool = False
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# CogView4 specific argument validation
if hasattr(args, 'fp8_base_unet') and args.fp8_base_unet:
logger.warning("FP8 training is not yet fully supported for CogView4. Disabling fp8_base_unet.")
args.fp8_base_unet = False
if hasattr(args, 'cache_text_encoder_outputs') and args.cache_text_encoder_outputs:
logger.warning("Text encoder output caching is not yet implemented for CogView4. Disabling.")
args.cache_text_encoder_outputs = False
if hasattr(args, 'cache_text_encoder_outputs_to_disk') and args.cache_text_encoder_outputs_to_disk:
logger.warning("Text encoder output disk caching is not yet implemented for CogView4. Disabling.")
args.cache_text_encoder_outputs_to_disk = False
# Set default values for CogView4
if not hasattr(args, 'max_token_length'):
args.max_token_length = 128 # Default token length for GLM
if not hasattr(args, 'resolution'):
args.resolution = 256 # Default resolution for CogView4
# Update dataset resolution if needed
train_dataset_group.set_resolution(args.resolution)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def load_target_model(self, args, weight_dtype, accelerator):
"""
Load the CogView4 model components including tokenizer, text encoder, VAE, and transformer.
"""
logger.info(f"Loading CogView4 model from {args.pretrained_model_name_or_path}")
# Load tokenizer and text encoder (GLM)
self.tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
use_fast=False,
trust_remote_code=True
)
self.text_encoder = GlmModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
torch_dtype=weight_dtype,
trust_remote_code=True
)
self.text_encoder.eval()
# Load VAE
self.vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
torch_dtype=weight_dtype,
trust_remote_code=True
)
self.vae.eval()
# Load transformer
self.transformer = CogView4Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=weight_dtype,
trust_remote_code=True
)
        # Create noise scheduler (CogView4 uses flow matching; FlowMatchEulerDiscreteScheduler
        # does not take beta_* or prediction_type arguments)
        self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
# Move models to device
device = accelerator.device
self.text_encoder = self.text_encoder.to(device)
self.vae = self.vae.to(device)
self.transformer = self.transformer.to(device)
# Set gradient checkpointing if enabled
if args.gradient_checkpointing:
self.transformer.enable_gradient_checkpointing()
# Text encoder gradient checkpointing is handled by prepare_text_encoder_grad_ckpt_workaround
# called by the base trainer if args.train_text_encoder is true for that TE.
# Store components for later use
self.weight_dtype = weight_dtype
# Return components in the expected format
return "cogview4-v1", [self.text_encoder], self.vae, self.transformer
def get_tokenize_strategy(self, args):
# For CogView4, we use a fixed token length for GLM
max_token_length = getattr(args, 'max_token_length', 128)
logger.info(f"Using max_token_length: {max_token_length} for GLM tokenizer")
return strategy_cogview4.CogView4TokenizeStrategy(max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy):
# For CogView4, we only have one tokenizer (GLM)
return [tokenize_strategy.tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_cogview4.CogView4LatentsCachingStrategy(
args.cache_latents_to_disk,
args.vae_batch_size,
skip_disk_cache_validity_check=False
)
def get_text_encoding_strategy(self, args):
# For CogView4, we use GLM instead of T5, but maintain similar interface
return strategy_cogview4.CogView4TextEncodingStrategy(
apply_attention_mask=getattr(args, 'apply_attention_mask', True)
)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
# For CogView4, we always return the text encoder (GLM) as it's needed for encoding
return text_encoders
def get_text_encoders_train_flags(self, args, text_encoders):
# For CogView4, we only have one text encoder (GLM)
return [getattr(args, 'train_text_encoder', False)]
def get_text_encoder_outputs_caching_strategy(self, args):
if getattr(args, 'cache_text_encoder_outputs', False):
return strategy_cogview4.CogView4TextEncoderOutputsCachingStrategy(
cache_to_disk=getattr(args, 'cache_text_encoder_outputs_to_disk', False),
batch_size=getattr(args, 'text_encoder_batch_size', 1),
skip_disk_cache_validity_check=getattr(args, 'skip_cache_check', False),
is_partial=getattr(args, 'train_text_encoder', False)
)
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
"""Cache text encoder outputs to speed up training.
Args:
args: Training arguments
accelerator: Accelerator instance
unet: UNet model
vae: VAE model
text_encoders: List containing the GLM text encoder
dataset: Dataset to cache text encoder outputs for
weight_dtype: Data type for weights
"""
if getattr(args, 'cache_text_encoder_outputs', False):
if not getattr(args, 'lowram', False):
# Free up GPU memory by moving models to CPU
logger.info("Moving VAE and UNet to CPU to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# Move text encoder to GPU with proper dtype
logger.info("Moving text encoder to GPU")
text_encoder = text_encoders[0] # CogView4 uses a single text encoder (GLM)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# Cache text encoder outputs
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# Cache sample prompts if provided
if getattr(args, 'sample_prompts', None) is not None:
logger.info(f"Caching text encoder outputs for sample prompts: {args.sample_prompts}")
# Initialize CogView4 strategies
tokenize_strategy = strategy_cogview4.CogView4TokenizeStrategy(
max_length=getattr(args, 'max_token_length', 128),
tokenizer_cache_dir=getattr(args, 'tokenizer_cache_dir', None)
)
text_encoding_strategy = strategy_cogview4.CogView4TextEncodingStrategy(
apply_attention_mask=getattr(args, 'apply_attention_mask', True)
)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p and p not in sample_prompts_te_outputs: # Skip empty prompts and duplicates
logger.info(f"Caching text encoder outputs for prompt: {p}")
tokens = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy=tokenize_strategy,
models=text_encoders,
tokens=tokens,
apply_attention_mask=getattr(args, 'apply_attention_mask', True)
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# Move text encoder back to CPU if not training it
if not getattr(args, 'train_text_encoder', False):
logger.info("Moving text encoder back to CPU")
text_encoder.to("cpu")
clean_memory_on_device(accelerator.device)
# Move VAE and UNet back to their original devices if not in lowram mode
if not getattr(args, 'lowram', False):
logger.info("Moving VAE and UNet back to original devices")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Keep text encoder in GPU if we're not caching outputs
if text_encoders:
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, transformer):
"""
Generate sample images during training to monitor progress.
"""
logger.info(f"Generating sample images at step {global_step}")
# Set models to eval mode
was_training = transformer.training
transformer.eval()
vae.eval()
# Sample prompts to use for generation
sample_prompts = [
"A high quality photo of a cat",
"A beautiful landscape with mountains and a lake",
"A futuristic city at night"
]
# Generate images for each prompt
all_images = []
with torch.no_grad():
for prompt in sample_prompts:
# Tokenize the prompt
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_input.input_ids.to(device)
attention_mask = text_input.attention_mask.to(device)
# Get text embeddings
with torch.no_grad():
text_embeddings = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
).last_hidden_state
# Sample random noise
latents = torch.randn(
(1, 4, args.resolution // 8, args.resolution // 8),
device=device,
dtype=torch.float32
)
# Set the scheduler for inference
self.noise_scheduler.set_timesteps(50, device=device)
# Generate image using the denoising process
for t in self.noise_scheduler.timesteps:
# Expand the latents if we are doing classifier-free guidance
latent_model_input = torch.cat([latents] * 2) if args.guidance_scale > 1.0 else latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
# Predict the noise residual
with torch.no_grad():
noise_pred = transformer(
latent_model_input,
t.unsqueeze(0).repeat(latent_model_input.shape[0]),
                            # unconditional branch uses a zero embedding, matching get_noise_pred_and_target
                            encoder_hidden_states=torch.cat([torch.zeros_like(text_embeddings), text_embeddings]) if args.guidance_scale > 1.0 else text_embeddings,
                            attention_mask=torch.cat([attention_mask] * 2) if args.guidance_scale > 1.0 else attention_mask,
return_dict=True,
).sample
# Perform guidance
if args.guidance_scale > 1.0:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_text - noise_pred_uncond)
# Compute the previous noisy sample x_t -> x_t-1
latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
                # Scale and decode the image latents with the VAE's configured scaling factor
                latents = latents / vae.config.scaling_factor
with torch.no_grad():
image = vae.decode(latents).sample
# Convert to PIL image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
all_images.extend(pil_images)
# Log images to tensorboard if available
if accelerator.is_main_process and hasattr(accelerator, "log"):
log_images = []
for i, img in enumerate(all_images):
# Convert PIL image to numpy for logging
log_images.append(np.array(img))
# Save individual images
os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
img.save(os.path.join(args.output_dir, "samples", f"sample_epoch{epoch}_step{global_step}_{i}.png"))
            # Log to the configured trackers (wandb.Image requires wandb to be installed)
            try:
                import wandb

                accelerator.log(
                    {"samples": [wandb.Image(img, caption=sample_prompts[i]) for i, img in enumerate(log_images)]},
                    step=global_step,
                )
            except ImportError:
                logger.warning("wandb is not installed; skipping sample image logging to trackers")
        # Restore the transformer's training mode; the VAE is not trained, so keep it in eval mode
        if was_training:
            transformer.train()
return all_images
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
# Return the noise scheduler that was created during model loading
return self.noise_scheduler
    def encode_images_to_latents(self, args, vae, images):
        # AutoencoderKL.encode returns a distribution; sample it to obtain latents
        return vae.encode(images).latent_dist.sample()
def shift_scale_latents(self, args, latents):
return latents
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
transformer: CogView4Transformer2DModel,
network,
weight_dtype,
train_unet=True,
is_train=True,
):
"""
Get noise prediction and target for the loss computation.
"""
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
).long()
# Add noise to the latents according to the noise magnitude at each timestep
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get text embeddings
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask", None)
# Prepare text encoder outputs
with torch.set_grad_enabled(self.train_text_encoder):
text_embeddings = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
).last_hidden_state
# Predict the noise residual
with torch.set_grad_enabled(train_unet):
# Add conditioning dropout for classifier-free guidance
if args.guidance_scale > 1.0:
# Randomly drop text conditioning 5% of the time
mask = (torch.rand(bsz, device=latents.device) < 0.05).float().unsqueeze(1).unsqueeze(1)
text_embeddings = text_embeddings * (1 - mask) + torch.zeros_like(text_embeddings) * mask
# Predict noise
model_pred = transformer(
noisy_latents,
timesteps,
encoder_hidden_states=text_embeddings,
attention_mask=attention_mask,
return_dict=True,
).sample
# Calculate target
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# For classifier-free guidance, we need to do two forward passes
if args.guidance_scale > 1.0:
# Predict the conditional and unconditional outputs
model_pred_uncond = transformer(
noisy_latents,
timesteps,
encoder_hidden_states=torch.zeros_like(text_embeddings),
attention_mask=attention_mask,
return_dict=True,
).sample
            # Perform classifier-free guidance (apply it once; applying it again for training would double-count the guidance)
            model_pred = model_pred_uncond + args.guidance_scale * (model_pred - model_pred_uncond)
# Simple weighting - can be adjusted based on timestep if needed
weighting = torch.ones_like(timesteps, dtype=weight_dtype, device=latents.device)
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
"""
Post-process the loss value.
This can include applying timestep weighting, gradient clipping, etc.
"""
# Apply timestep weighting if specified
if hasattr(args, 'timestep_bias_portion') and args.timestep_bias_portion > 0.0:
# Simple timestep weighting - can be made more sophisticated if needed
weights = torch.ones_like(timesteps, dtype=torch.float32)
if hasattr(args, 'timestep_bias_begin') and args.timestep_bias_begin > 0:
mask = timesteps < args.timestep_bias_begin
weights[mask] = 0.0
if hasattr(args, 'timestep_bias_end') and args.timestep_bias_end < 1000:
mask = timesteps > args.timestep_bias_end
weights[mask] = 0.0
if hasattr(args, 'timestep_bias_multiplier') and args.timestep_bias_multiplier != 1.0:
weights = weights * args.timestep_bias_multiplier
            # Reshape so the per-sample weights broadcast over any extra loss dimensions
            loss = loss * weights.to(loss.device).view(-1, *([1] * (loss.dim() - 1)))
# Clip loss values if specified
if hasattr(args, 'clip_grad_norm') and args.clip_grad_norm > 0.0:
loss = torch.clamp(loss, -args.clip_grad_norm, args.clip_grad_norm)
return loss
def prepare_extra_step_kwargs(self, generator, eta):
"""
Prepare extra kwargs for the scheduler step, such as the generator for reproducibility.
"""
# Prepare extra step kwargs.
# TODO: Logic should ideally just be moved to base class
accepts_eta = "eta" in set(inspect.signature(self.noise_scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# Check if the scheduler accepts a generator
accepts_generator = "generator" in set(inspect.signature(self.noise_scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def update_metadata(self, metadata, args):
metadata["ss_apply_attn_mask"] = args.apply_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_guidance_scale"] = args.guidance_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
def is_text_encoder_not_needed_for_training(self, args):
"""Check if text encoder outputs are cached and not being trained."""
return getattr(args, 'cache_text_encoder_outputs', False) and not getattr(args, 'train_text_encoder', False)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
"""Prepare text encoder for gradient checkpointing.
For CogView4, we only have one text encoder (GLM) so we don't need index-based handling.
The base class method handles enabling gradient checkpointing if args specify it.
"""
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
"""Prepare text encoder for FP8 training.
Args:
index: Text encoder index (always 0 for CogView4)
text_encoder: The text encoder model (GLM)
te_weight_dtype: Target weight dtype for the encoder
weight_dtype: Base weight dtype for embeddings
"""
if index != 0:
logger.warning(f"Unexpected text encoder index {index} for CogView4, expecting 0.")
# Still proceed, assuming it's the single GLM encoder
logger.info(f"Preparing GLM text encoder (index {index}) for {te_weight_dtype}, embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype)
# Move embeddings to base weight dtype if they exist
if hasattr(text_encoder, 'word_embeddings'): # GLM typically has word_embeddings
text_encoder.word_embeddings.to(dtype=weight_dtype)
if hasattr(text_encoder, 'position_embeddings'): # GLM might have position_embeddings
text_encoder.position_embeddings.to(dtype=weight_dtype)
# Add other relevant parts of GLM if they need specific dtype handling for FP8
return text_encoder
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
"""Called at the end of each validation step."""
# No special handling needed for CogView4 (e.g., no block swapping)
pass
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
"""Prepare UNet model (CogView4Transformer2DModel) with accelerator.
For CogView4, we use standard model preparation.
"""
return super().prepare_unet_with_accelerator(args, accelerator, unet)
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = CogView4NetworkTrainer()
trainer.train(args)


@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
         x_B_T_H_W_D: torch.Tensor,
         emb_B_T_D: torch.Tensor,
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
+        use_fp32: bool = False,
     ):
         # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
-        use_fp32 = x_B_T_H_W_D.dtype == torch.float16
         with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
             if self.use_adaln_lora:
                 assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
         emb_B_T_D: torch.Tensor,
         crossattn_emb: torch.Tensor,
         attn_params: attention.AttentionParams,
+        use_fp32: bool = False,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        use_fp32 = x_B_T_H_W_D.dtype == torch.float16
         if use_fp32:
             # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
             x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,6 +959,7 @@ class Block(nn.Module):
         emb_B_T_D: torch.Tensor,
         crossattn_emb: torch.Tensor,
         attn_params: attention.AttentionParams,
+        use_fp32: bool = False,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -972,6 +973,7 @@ class Block(nn.Module):
                 emb_B_T_D,
                 crossattn_emb,
                 attn_params,
+                use_fp32,
                 rope_emb_L_1_1_D,
                 adaln_lora_B_T_3D,
                 extra_per_block_pos_emb,
@@ -994,6 +996,7 @@ class Block(nn.Module):
                 emb_B_T_D,
                 crossattn_emb,
                 attn_params,
+                use_fp32,
                 rope_emb_L_1_1_D,
                 adaln_lora_B_T_3D,
                 extra_per_block_pos_emb,
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
                 emb_B_T_D,
                 crossattn_emb,
                 attn_params,
+                use_fp32,
                 rope_emb_L_1_1_D,
                 adaln_lora_B_T_3D,
                 extra_per_block_pos_emb,
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
                 emb_B_T_D,
                 crossattn_emb,
                 attn_params,
+                use_fp32,
                 rope_emb_L_1_1_D,
                 adaln_lora_B_T_3D,
                 extra_per_block_pos_emb,
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
         attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
+        # Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
+        use_fp32 = x_B_T_H_W_D.dtype == torch.float16
         for block_idx, block in enumerate(self.blocks):
             if self.blocks_to_swap:
                 self.offloader.wait_for_block(block_idx)
-            x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
+            x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
             if self.blocks_to_swap:
                 self.offloader.submit_move_blocks(self.blocks, block_idx)
-        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
+        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
         x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
         return x_B_C_Tt_Hp_Wp
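
In effect, PR #2302 hoists the fp16 check into Anima.forward and threads a single use_fp32 flag down to every Block and to final_layer, so the numerically sensitive computations run in float32 and the enclosing autocast contexts cast results back to float16. Below is a minimal, self-contained sketch of that pattern; ToyBlock, its dimensions, and the explicit casts (in place of the autocast regions used in the actual code) are illustrative only, not the Anima implementation.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Toy residual block: run the sensitive math in float32 when activations are float16."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, use_fp32: bool = False) -> torch.Tensor:
        in_dtype = x.dtype
        if use_fp32:
            x = x.float()  # keep the residual path in float32 for stability
        h = self.proj(self.norm(x.float()))  # sensitive computation in float32
        out = x + h
        # the real model relies on autocast contexts to cast back; here we cast explicitly
        return out if use_fp32 else out.to(in_dtype)


x = torch.randn(2, 8, 64, dtype=torch.float16)
use_fp32 = x.dtype == torch.float16  # decided once by the caller, as Anima.forward now does
block = ToyBlock()
print(block(x, use_fp32=use_fp32).dtype)  # float32 here; the final layer handles casting back

Deciding the flag once at the top of Anima.forward keeps every block and the final layer consistent, instead of each module re-deriving it from intermediate dtypes.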


@@ -0,0 +1,379 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import AutoTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
GLM_TOKENIZER_ID = "THUDM/CogView4-6B"
class CogView4TokenizeStrategy(TokenizeStrategy):
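    """Tokenizes prompts with the CogView4 GLM tokenizer, padding/truncating to max_length."""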
def __init__(self, max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
self.max_length = max_length
self.tokenizer = self._load_tokenizer(AutoTokenizer, GLM_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
# Add special tokens if needed
self.tokenizer.pad_token = self.tokenizer.eos_token
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
# Tokenize with GLM tokenizer
tokens = self.tokenizer(
text,
max_length=self.max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
)
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]
return [input_ids, attention_mask]
class CogView4TextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_attention_mask: bool = True) -> None:
"""
Args:
apply_attention_mask: Whether to apply attention mask during encoding.
"""
self.apply_attention_mask = apply_attention_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_attention_mask: Optional[bool] = None,
) -> List[torch.Tensor]:
# supports single model inference
if apply_attention_mask is None:
apply_attention_mask = self.apply_attention_mask
# Get GLM model (should be the only model in the list)
glm_model = models[0]
input_ids = tokens[0]
attention_mask = tokens[1] if len(tokens) > 1 else None
# Move tensors to the correct device
device = glm_model.device
input_ids = input_ids.to(device)
if attention_mask is not None:
attention_mask = attention_mask.to(device)
# Get GLM model outputs
with torch.no_grad():
outputs = glm_model(
input_ids=input_ids,
attention_mask=attention_mask if apply_attention_mask else None,
output_hidden_states=True,
return_dict=True
)
# Get the last hidden state
hidden_states = outputs.hidden_states[-1] # [batch_size, seq_len, hidden_size]
# For compatibility with existing code, we'll return a list similar to the original
# but with GLM's hidden states instead of CLIP/T5 outputs
return [
hidden_states, # Replaces l_pooled
hidden_states, # Replaces t5_out (same tensor for now, can be modified if needed)
torch.zeros(hidden_states.shape[0], hidden_states.shape[1], 3, device=device), # txt_ids placeholder
attention_mask # attention mask
]
class CogView4TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_cogview4_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_attention_mask: bool = True,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_attention_mask = apply_attention_mask
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + CogView4TextEncoderOutputsCachingStrategy.COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
required_fields = ["hidden_states", "attention_mask", "apply_attention_mask"]
for field in required_fields:
if field not in npz:
return False
npz_apply_attention_mask = bool(npz["apply_attention_mask"])
if npz_apply_attention_mask != self.apply_attention_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
logger.exception(e)
return False
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_states = data["hidden_states"]
attention_mask = data["attention_mask"]
return [
hidden_states, # l_pooled replacement
hidden_states, # t5_out replacement
np.zeros((hidden_states.shape[0], hidden_states.shape[1], 3), dtype=np.float32), # txt_ids
attention_mask # attention mask
]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
if not self.warn_fp8_weights:
model_dtype = next(models[0].parameters()).dtype
if model_dtype == torch.float8_e4m3fn or model_dtype == torch.float8_e5m2:
logger.warning(
"Model is using fp8 weights for caching. This may affect the quality of the cached outputs."
" / モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
)
self.warn_fp8_weights = True
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_states, _, _, attention_mask = text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks
)
if hidden_states.dtype == torch.bfloat16:
hidden_states = hidden_states.float()
hidden_states = hidden_states.cpu().numpy()
attention_mask = attention_mask.cpu().numpy() if attention_mask is not None else None
for i, info in enumerate(infos):
hidden_states_i = hidden_states[i]
attention_mask_i = attention_mask[i] if attention_mask is not None else None
if self.cache_to_disk and hasattr(info, 'text_encoder_outputs_npz'):
np.savez(
info.text_encoder_outputs_npz,
hidden_states=hidden_states_i,
attention_mask=attention_mask_i,
apply_attention_mask=self.apply_attention_mask,
)
else:
info.text_encoder_outputs = (hidden_states_i, hidden_states_i, np.zeros((hidden_states_i.shape[0], 3), dtype=np.float32), attention_mask_i)
class CogView4LatentsCachingStrategy(LatentsCachingStrategy):
COGVIEW4_LATENTS_NPZ_SUFFIX = "_cogview4.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
"""Get the path for cached latents.
Args:
absolute_path: Absolute path to the source image
image_size: Tuple of (height, width) for the target resolution
Returns:
Path to the cached latents file
"""
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool
) -> bool:
"""Check if the latents are already cached and valid.
Args:
bucket_reso: Target resolution as (height, width)
npz_path: Path to the cached latents file
flip_aug: Whether flip augmentation was applied
alpha_mask: Whether alpha mask was used
Returns:
bool: True if valid cache exists, False otherwise
"""
# Using 8 as the default number of frames for compatibility
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
def load_latents_from_disk(
self,
npz_path: str,
bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""Load latents from disk.
Args:
npz_path: Path to the cached latents file
bucket_reso: Target resolution as (height, width)
Returns:
Tuple containing:
- latents: The loaded latents or None if loading failed
- original_size: Original image size as [height, width]
- crop_top_left: Crop offset as [top, left]
- alpha_mask: Alpha mask if available
- alpha_mask_origin: Original alpha mask if available
"""
# Using 8 as the default number of frames for compatibility
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
def cache_batch_latents(
self,
vae: Any,
image_infos: List[Any],
flip_aug: bool,
alpha_mask: bool,
random_crop: bool
) -> None:
"""Cache a batch of latents.
Args:
vae: The VAE model used for encoding
image_infos: List of image information objects
flip_aug: Whether to apply flip augmentation
alpha_mask: Whether to use alpha mask
random_crop: Whether to apply random crop
"""
        # Define encoding function that samples latents from the VAE and moves them to CPU
        def encode_by_vae(img_tensor: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                return vae.encode(img_tensor).latent_dist.sample().to("cpu")
# Get VAE device and dtype
vae_device = vae.device
vae_dtype = vae.dtype
# Cache latents using the default implementation
self._default_cache_batch_latents(
encode_by_vae,
vae_device,
vae_dtype,
image_infos,
flip_aug,
alpha_mask,
random_crop,
multi_resolution=True
)
# Clean up GPU memory if not in high VRAM mode
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# Test code for CogView4TokenizeStrategy
tokenizer = CogView4TokenizeStrategy(512)
text = "hello world"
# Test single text tokenization
input_ids, attention_mask = tokenizer.tokenize(text)
print("Input IDs:", input_ids)
print("Attention Mask:", attention_mask)
# Test batch tokenization
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
batch_input_ids, batch_attention_mask = tokenizer.tokenize(texts)
print("\nBatch Input IDs:", batch_input_ids.shape)
print("Batch Attention Mask:", batch_attention_mask.shape)
# Test with a long text
long_text = ",".join(["hello world! this is long text"] * 10)
long_input_ids, long_attention_mask = tokenizer.tokenize(long_text)
print("\nLong text input IDs shape:", long_input_ids.shape)
print("Long text attention mask shape:", long_attention_mask.shape)
# Test text encoding strategy
print("\nTesting text encoding strategy...")
from transformers import AutoModel
    # Load a GLM model for testing (note: this checkpoint is a large download)
model = AutoModel.from_pretrained("THUDM/glm-10b-chinese", trust_remote_code=True)
model.eval()
encoding_strategy = CogView4TextEncodingStrategy()
tokens = tokenizer.tokenize(texts)
encoded = encoding_strategy.encode_tokens(tokenizer, [model], tokens)
print(f"Number of outputs: {len(encoded)}")
print(f"Hidden states shape: {encoded[0].shape}")
print(f"Attention mask shape: {encoded[3].shape if encoded[3] is not None else 'None'}")
# Test caching strategy
print("\nTesting caching strategy...")
import tempfile
import os
class DummyInfo:
def __init__(self, caption):
self.caption = caption
self.text_encoder_outputs_npz = tempfile.mktemp(suffix=".npz")
# Create test data
infos = [DummyInfo(text) for text in texts]
# Test caching
caching_strategy = CogView4TextEncoderOutputsCachingStrategy(
cache_to_disk=True,
batch_size=2,
skip_disk_cache_validity_check=False
)
# Cache the outputs
caching_strategy.cache_batch_outputs(tokenizer, [model], encoding_strategy, infos)
# Check if files were created
for info in infos:
exists = os.path.exists(info.text_encoder_outputs_npz)
print(f"Cache file {info.text_encoder_outputs_npz} exists: {exists}")
# Clean up
for info in infos:
if os.path.exists(info.text_encoder_outputs_npz):
os.remove(info.text_encoder_outputs_npz)