mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Merge 201e1997a2 into 1dae34b0af
This commit is contained in:
617
cogview4_train_network.py
Normal file
617
cogview4_train_network.py
Normal file
@@ -0,0 +1,617 @@
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
import os
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.pipelines.cogview4.pipeline_cogview4 import CogView4Pipeline
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from library.device_utils import clean_memory_on_device, init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
import train_network
|
||||
from library import (
|
||||
flux_models,
|
||||
flux_train_utils,
|
||||
flux_utils,
|
||||
sd3_train_utils,
|
||||
strategy_base,
|
||||
strategy_cogview4,
|
||||
train_util,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CogView4NetworkTrainer(train_network.NetworkTrainer):
    """LoRA/network trainer for CogView4 (GLM text encoder + CogView4 DiT transformer)."""

    def __init__(self):
        super().__init__()
        # Text-encoder outputs for sample prompts, filled by
        # cache_text_encoder_outputs_if_needed; None until caching runs.
        self.sample_prompts_te_outputs = None
        # CogView4 does not use block swapping; kept for base-class compatibility.
        self.is_swapping_blocks: bool = False
|
||||
|
||||
def assert_extra_args(
    self,
    args,
    train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
    val_dataset_group: Optional[train_util.DatasetGroup],
):
    """Validate and normalize CLI arguments for CogView4 training.

    Unsupported features (fp8 base model, text-encoder output caching) are
    disabled in place with a warning instead of raising, and the dataset
    resolution is synchronized with ``args.resolution``.
    """
    super().assert_extra_args(args, train_dataset_group, val_dataset_group)

    # CogView4 specific argument validation
    if hasattr(args, 'fp8_base_unet') and args.fp8_base_unet:
        logger.warning("FP8 training is not yet fully supported for CogView4. Disabling fp8_base_unet.")
        args.fp8_base_unet = False

    # NOTE(review): this force-disables caching even though
    # cache_text_encoder_outputs_if_needed below implements it — confirm which
    # is intended before relying on the caching path.
    if hasattr(args, 'cache_text_encoder_outputs') and args.cache_text_encoder_outputs:
        logger.warning("Text encoder output caching is not yet implemented for CogView4. Disabling.")
        args.cache_text_encoder_outputs = False

    if hasattr(args, 'cache_text_encoder_outputs_to_disk') and args.cache_text_encoder_outputs_to_disk:
        logger.warning("Text encoder output disk caching is not yet implemented for CogView4. Disabling.")
        args.cache_text_encoder_outputs_to_disk = False

    # Set default values for CogView4.
    # NOTE(review): argparse defines every declared option as an attribute
    # (possibly None), so these hasattr checks may never fire if the parser
    # already declares max_token_length/resolution — verify against setup_parser.
    if not hasattr(args, 'max_token_length'):
        args.max_token_length = 128  # Default token length for GLM

    if not hasattr(args, 'resolution'):
        args.resolution = 256  # Default resolution for CogView4

    # Update dataset resolution if needed
    # NOTE(review): assumes DatasetGroup exposes set_resolution and
    # verify_bucket_reso_steps — confirm against library.train_util.
    train_dataset_group.set_resolution(args.resolution)
    if val_dataset_group is not None:
        val_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
    """
    Load the CogView4 model components (GLM tokenizer + text encoder, VAE,
    transformer), build the training noise scheduler, and move everything to
    the accelerator device.

    Returns:
        Tuple of (model version string, [text_encoder], vae, transformer) in
        the format expected by the base NetworkTrainer.
    """
    logger.info(f"Loading CogView4 model from {args.pretrained_model_name_or_path}")

    # Load tokenizer and text encoder (GLM)
    self.tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        use_fast=False,
        trust_remote_code=True
    )

    self.text_encoder = GlmModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        torch_dtype=weight_dtype,
        trust_remote_code=True
    )
    # Text encoder is frozen by default; training it is controlled elsewhere.
    self.text_encoder.eval()

    # Load VAE
    self.vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        torch_dtype=weight_dtype,
        trust_remote_code=True
    )
    self.vae.eval()

    # Load transformer
    self.transformer = CogView4Transformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=weight_dtype,
        trust_remote_code=True
    )

    # Create noise scheduler
    # NOTE(review): beta_start/beta_end/beta_schedule/prediction_type are
    # DDPM-style parameters; FlowMatchEulerDiscreteScheduler's constructor
    # does not accept them in current diffusers and this call may raise a
    # TypeError — verify against the installed diffusers version.
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        prediction_type="epsilon",
    )

    # Move models to device.
    # NOTE(review): moving the text encoder, VAE and transformer to the GPU
    # up front is memory-heavy; the base trainer may re-place models later.
    device = accelerator.device
    self.text_encoder = self.text_encoder.to(device)
    self.vae = self.vae.to(device)
    self.transformer = self.transformer.to(device)

    # Set gradient checkpointing if enabled
    if args.gradient_checkpointing:
        self.transformer.enable_gradient_checkpointing()
        # Text encoder gradient checkpointing is handled by prepare_text_encoder_grad_ckpt_workaround
        # called by the base trainer if args.train_text_encoder is true for that TE.

    # Store components for later use
    self.weight_dtype = weight_dtype

    # Return components in the expected format
    return "cogview4-v1", [self.text_encoder], self.vae, self.transformer
|
||||
|
||||
def get_tokenize_strategy(self, args):
    """Build the GLM tokenize strategy, defaulting to 128 tokens."""
    length = getattr(args, 'max_token_length', 128)
    logger.info(f"Using max_token_length: {length} for GLM tokenizer")
    strategy = strategy_cogview4.CogView4TokenizeStrategy(length, args.tokenizer_cache_dir)
    return strategy
|
||||
|
||||
def get_tokenizers(self, tokenize_strategy):
    """CogView4 uses exactly one tokenizer (GLM); return it as a list."""
    tokenizers = [tokenize_strategy.tokenizer]
    return tokenizers
|
||||
|
||||
def get_latents_caching_strategy(self, args):
    """Create the CogView4 latents caching strategy.

    Disk-cache validity checks are always performed for latents.
    """
    return strategy_cogview4.CogView4LatentsCachingStrategy(
        args.cache_latents_to_disk, args.vae_batch_size, skip_disk_cache_validity_check=False
    )
|
||||
|
||||
def get_text_encoding_strategy(self, args):
    """Create the GLM-based text encoding strategy (T5-like interface)."""
    use_mask = getattr(args, 'apply_attention_mask', True)
    return strategy_cogview4.CogView4TextEncodingStrategy(apply_attention_mask=use_mask)
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
    """All encoders (just GLM) are needed for encoding; pass through unchanged."""
    return text_encoders
|
||||
|
||||
def get_text_encoders_train_flags(self, args, text_encoders):
    """One trainability flag for the single GLM encoder."""
    train_te = getattr(args, 'train_text_encoder', False)
    return [train_te]
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
    """Return the TE-output caching strategy, or None when caching is off."""
    if not getattr(args, 'cache_text_encoder_outputs', False):
        return None
    return strategy_cogview4.CogView4TextEncoderOutputsCachingStrategy(
        cache_to_disk=getattr(args, 'cache_text_encoder_outputs_to_disk', False),
        batch_size=getattr(args, 'text_encoder_batch_size', 1),
        skip_disk_cache_validity_check=getattr(args, 'skip_cache_check', False),
        # A trainable TE means the cache only partially covers the outputs.
        is_partial=getattr(args, 'train_text_encoder', False),
    )
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
    self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
    """Cache text encoder outputs to speed up training.

    When caching is enabled, temporarily parks the VAE/UNet on CPU (unless
    lowram), runs the GLM encoder over the dataset captions and any sample
    prompts, then restores device placement.

    Args:
        args: Training arguments
        accelerator: Accelerator instance
        unet: UNet (CogView4 transformer) model
        vae: VAE model
        text_encoders: List containing the GLM text encoder
        dataset: Dataset to cache text encoder outputs for
        weight_dtype: Data type for weights
    """
    if getattr(args, 'cache_text_encoder_outputs', False):
        if not getattr(args, 'lowram', False):
            # Free up GPU memory by moving models to CPU; original devices
            # are remembered so they can be restored below.
            logger.info("Moving VAE and UNet to CPU to save memory")
            org_vae_device = vae.device
            org_unet_device = unet.device
            vae.to("cpu")
            unet.to("cpu")
            clean_memory_on_device(accelerator.device)

        # Move text encoder to GPU with proper dtype
        logger.info("Moving text encoder to GPU")
        text_encoder = text_encoders[0]  # CogView4 uses a single text encoder (GLM)
        text_encoder.to(accelerator.device, dtype=weight_dtype)

        # Cache text encoder outputs for the whole dataset
        with accelerator.autocast():
            dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

        # Cache sample prompts if provided
        if getattr(args, 'sample_prompts', None) is not None:
            logger.info(f"Caching text encoder outputs for sample prompts: {args.sample_prompts}")

            # Initialize CogView4 strategies.
            # NOTE(review): this bypasses get_tokenize_strategy /
            # get_text_encoding_strategy; keep the parameters in sync with them.
            tokenize_strategy = strategy_cogview4.CogView4TokenizeStrategy(
                max_length=getattr(args, 'max_token_length', 128),
                tokenizer_cache_dir=getattr(args, 'tokenizer_cache_dir', None)
            )
            text_encoding_strategy = strategy_cogview4.CogView4TextEncodingStrategy(
                apply_attention_mask=getattr(args, 'apply_attention_mask', True)
            )

            prompts = train_util.load_prompts(args.sample_prompts)
            sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs

            with accelerator.autocast(), torch.no_grad():
                for prompt_dict in prompts:
                    for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                        if p and p not in sample_prompts_te_outputs:  # Skip empty prompts and duplicates
                            logger.info(f"Caching text encoder outputs for prompt: {p}")
                            tokens = tokenize_strategy.tokenize(p)
                            sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                tokenize_strategy=tokenize_strategy,
                                models=text_encoders,
                                tokens=tokens,
                                apply_attention_mask=getattr(args, 'apply_attention_mask', True)
                            )
            self.sample_prompts_te_outputs = sample_prompts_te_outputs

        accelerator.wait_for_everyone()

        # Move text encoder back to CPU if not training it
        if not getattr(args, 'train_text_encoder', False):
            logger.info("Moving text encoder back to CPU")
            text_encoder.to("cpu")
            clean_memory_on_device(accelerator.device)

        # Move VAE and UNet back to their original devices if not in lowram mode
        if not getattr(args, 'lowram', False):
            logger.info("Moving VAE and UNet back to original devices")
            vae.to(org_vae_device)
            unet.to(org_unet_device)
    else:
        # Keep text encoder on GPU if we're not caching outputs — it is
        # needed every step for on-the-fly encoding.
        if text_encoders:
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# # get size embeddings
|
||||
# orig_size = batch["original_sizes_hw"]
|
||||
# crop_size = batch["crop_top_lefts"]
|
||||
# target_size = batch["target_sizes_hw"]
|
||||
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# # concat embeddings
|
||||
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
||||
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
# return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, transformer):
    """Generate sample images during training to monitor progress.

    Runs a simple text-to-image loop with the current transformer and VAE,
    saves PNGs under ``{output_dir}/samples`` and, when wandb is installed,
    logs them via ``accelerator.log``.

    Fixes over the original:
      * ``wandb`` was referenced without ever being imported (NameError on
        the logging path); it is now imported lazily and skipped if absent.
      * the CFG branch doubled the latents but fed the *conditional*
        embeddings for both halves and did not double the attention mask,
        making guidance a no-op with mismatched shapes; a true unconditional
        (empty-prompt) branch is used now.
      * the VAE's train/eval state is recorded and restored independently of
        the transformer's, instead of forcing a frozen VAE into train mode.

    Returns:
        List of PIL images, one per hard-coded sample prompt.
    """
    logger.info(f"Generating sample images at step {global_step}")

    # Remember training states so they can be restored afterwards.
    transformer_was_training = transformer.training
    vae_was_training = vae.training
    transformer.eval()
    vae.eval()

    # Sample prompts to use for generation
    sample_prompts = [
        "A high quality photo of a cat",
        "A beautiful landscape with mountains and a lake",
        "A futuristic city at night",
    ]

    do_cfg = args.guidance_scale > 1.0
    all_images = []

    def _encode(prompt_text):
        # Tokenize and encode a single prompt with the GLM text encoder.
        enc = self.tokenizer(
            prompt_text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        ids = enc.input_ids.to(device)
        mask = enc.attention_mask.to(device)
        emb = self.text_encoder(input_ids=ids, attention_mask=mask, return_dict=True).last_hidden_state
        return emb, mask

    with torch.no_grad():
        for prompt in sample_prompts:
            text_embeddings, attention_mask = _encode(prompt)
            if do_cfg:
                # Unconditional branch uses the empty prompt; order is
                # [uncond, cond] to match the chunk(2) below.
                uncond_embeddings, uncond_mask = _encode("")
                cond_embeddings = torch.cat([uncond_embeddings, text_embeddings])
                cond_mask = torch.cat([uncond_mask, attention_mask])
            else:
                cond_embeddings = text_embeddings
                cond_mask = attention_mask

            # NOTE(review): assumes a 4-channel, /8 latent space (SD-style);
            # confirm against the CogView4 VAE/transformer configs.
            latents = torch.randn(
                (1, 4, args.resolution // 8, args.resolution // 8),
                device=device,
                dtype=torch.float32,
            )

            # Set the scheduler for inference (50 denoising steps)
            self.noise_scheduler.set_timesteps(50, device=device)

            for t in self.noise_scheduler.timesteps:
                # Double the latents for classifier-free guidance
                latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
                latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)

                noise_pred = transformer(
                    latent_model_input,
                    t.unsqueeze(0).repeat(latent_model_input.shape[0]),
                    encoder_hidden_states=cond_embeddings,
                    attention_mask=cond_mask,
                    return_dict=True,
                ).sample

                if do_cfg:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_t-1
                latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample

            # NOTE(review): 0.18215 is the SD1.x VAE scaling factor; verify it
            # matches vae.config.scaling_factor for the CogView4 VAE.
            latents = 1 / 0.18215 * latents
            image = vae.decode(latents).sample

            # Convert to PIL images
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
            images = (image * 255).round().astype("uint8")
            all_images.extend(Image.fromarray(im) for im in images)

    if accelerator.is_main_process and hasattr(accelerator, "log"):
        # Save individual images
        os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
        for i, img in enumerate(all_images):
            img.save(os.path.join(args.output_dir, "samples", f"sample_epoch{epoch}_step{global_step}_{i}.png"))

        try:
            import wandb  # optional dependency; the original used it without importing

            accelerator.log(
                {
                    "samples": [
                        wandb.Image(img, caption=f"{sample_prompts[i]}")
                        for i, img in enumerate(all_images)
                    ]
                },
                step=global_step,
            )
        except ImportError:
            logger.warning("wandb is not installed; skipping sample image logging.")

    # Restore the recorded training states.
    if transformer_was_training:
        transformer.train()
    if vae_was_training:
        vae.train()

    return all_images
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
    """Return the scheduler built in load_target_model()."""
    scheduler = self.noise_scheduler
    return scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, vae, images):
    """Encode pixel images to latents via the VAE (no extra scaling here)."""
    encoded = vae.encode(images)
    return encoded
|
||||
|
||||
def shift_scale_latents(self, args, latents):
    """CogView4 applies no latent shift/scale; pass through unchanged."""
    return latents
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
transformer: CogView4Transformer2DModel,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet=True,
|
||||
is_train=True,
|
||||
):
|
||||
"""
|
||||
Get noise prediction and target for the loss computation.
|
||||
"""
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get text embeddings
|
||||
input_ids = batch["input_ids"]
|
||||
attention_mask = batch.get("attention_mask", None)
|
||||
|
||||
# Prepare text encoder outputs
|
||||
with torch.set_grad_enabled(self.train_text_encoder):
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True
|
||||
).last_hidden_state
|
||||
|
||||
# Predict the noise residual
|
||||
with torch.set_grad_enabled(train_unet):
|
||||
# Add conditioning dropout for classifier-free guidance
|
||||
if args.guidance_scale > 1.0:
|
||||
# Randomly drop text conditioning 5% of the time
|
||||
mask = (torch.rand(bsz, device=latents.device) < 0.05).float().unsqueeze(1).unsqueeze(1)
|
||||
text_embeddings = text_embeddings * (1 - mask) + torch.zeros_like(text_embeddings) * mask
|
||||
|
||||
# Predict noise
|
||||
model_pred = transformer(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
).sample
|
||||
|
||||
# Calculate target
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# For classifier-free guidance, we need to do two forward passes
|
||||
if args.guidance_scale > 1.0:
|
||||
# Predict the conditional and unconditional outputs
|
||||
model_pred_uncond = transformer(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=torch.zeros_like(text_embeddings),
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
).sample
|
||||
|
||||
# Perform classifier-free guidance
|
||||
model_pred = model_pred_uncond + args.guidance_scale * (model_pred - model_pred_uncond)
|
||||
|
||||
# For training, we only compute the loss on the conditional prediction
|
||||
if is_train:
|
||||
model_pred = model_pred_uncond + args.guidance_scale * (model_pred - model_pred_uncond)
|
||||
|
||||
# Simple weighting - can be adjusted based on timestep if needed
|
||||
weighting = torch.ones_like(timesteps, dtype=weight_dtype, device=latents.device)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
    """Apply optional timestep-bias weighting and loss clamping.

    NOTE(review): clamping the *loss* with ``clip_grad_norm`` mirrors the
    original behavior, but that argument usually bounds the gradient norm,
    not the loss value — confirm intent.
    """
    # Zero out loss outside [timestep_bias_begin, timestep_bias_end] and
    # optionally scale the remainder.
    if hasattr(args, 'timestep_bias_portion') and args.timestep_bias_portion > 0.0:
        step_weights = torch.ones_like(timesteps, dtype=torch.float32)
        if hasattr(args, 'timestep_bias_begin') and args.timestep_bias_begin > 0:
            step_weights[timesteps < args.timestep_bias_begin] = 0.0
        if hasattr(args, 'timestep_bias_end') and args.timestep_bias_end < 1000:
            step_weights[timesteps > args.timestep_bias_end] = 0.0
        if hasattr(args, 'timestep_bias_multiplier') and args.timestep_bias_multiplier != 1.0:
            step_weights = step_weights * args.timestep_bias_multiplier
        loss = loss * step_weights.to(loss.device)

    # Optional symmetric clamp on the loss values.
    if hasattr(args, 'clip_grad_norm') and args.clip_grad_norm > 0.0:
        loss = torch.clamp(loss, -args.clip_grad_norm, args.clip_grad_norm)

    return loss
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build kwargs for scheduler.step(), passing eta/generator only when the
    scheduler's step signature accepts them.
    """
    # TODO: Logic should ideally just be moved to base class
    # Inspect the signature once and reuse it for both checks.
    step_params = set(inspect.signature(self.noise_scheduler.step).parameters.keys())

    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_apply_attn_mask"] = args.apply_attn_mask
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
metadata["ss_logit_mean"] = args.logit_mean
|
||||
metadata["ss_logit_std"] = args.logit_std
|
||||
metadata["ss_mode_scale"] = args.mode_scale
|
||||
metadata["ss_guidance_scale"] = args.guidance_scale
|
||||
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
||||
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
||||
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
    """True when TE outputs are cached and the encoder itself stays frozen."""
    caching = getattr(args, 'cache_text_encoder_outputs', False)
    training_te = getattr(args, 'train_text_encoder', False)
    return caching and not training_te
|
||||
|
||||
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
    """Delegate gradient-checkpointing setup for the single GLM encoder.

    CogView4 has only one text encoder, so no index-based handling is
    required; the base class enables gradient checkpointing when the args
    specify it.
    """
    return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
||||
|
||||
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
    """Prepare text encoder for FP8 training.

    Casts the whole GLM encoder to ``te_weight_dtype`` while keeping the
    embedding tables in the base ``weight_dtype``.

    Args:
        index: Text encoder index (always 0 for CogView4)
        text_encoder: The text encoder model (GLM)
        te_weight_dtype: Target weight dtype for the encoder
        weight_dtype: Base weight dtype for embeddings
    """
    if index != 0:
        logger.warning(f"Unexpected text encoder index {index} for CogView4, expecting 0.")
        # Still proceed, assuming it's the single GLM encoder

    logger.info(f"Preparing GLM text encoder (index {index}) for {te_weight_dtype}, embeddings to {weight_dtype}")
    text_encoder.to(te_weight_dtype)

    # Move embeddings to base weight dtype if they exist.
    # NOTE(review): GLM implementations often nest embeddings (e.g. under
    # model.embed_tokens); confirm these attribute names exist on the loaded
    # GlmModel — if not, these casts silently do nothing.
    if hasattr(text_encoder, 'word_embeddings'):  # GLM typically has word_embeddings
        text_encoder.word_embeddings.to(dtype=weight_dtype)
    if hasattr(text_encoder, 'position_embeddings'):  # GLM might have position_embeddings
        text_encoder.position_embeddings.to(dtype=weight_dtype)
    # Add other relevant parts of GLM if they need specific dtype handling for FP8
    return text_encoder
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    """Hook after each validation step; CogView4 needs no per-step cleanup
    (no block swapping)."""
    return None
|
||||
|
||||
def prepare_unet_with_accelerator(
    self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
    """Prepare the CogView4 transformer with accelerate.

    No block swapping or other CogView4-specific handling is needed, so the
    base-class preparation is used unchanged.
    """
    return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
    """Extend the base network-training parser with DiT and FLUX arguments.

    NOTE(review): CogView4 reuses flux_train_utils' argument set
    (guidance_scale, timestep_sampling, weighting_scheme, ...); confirm every
    argument this trainer reads is actually declared there.
    """
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)
    flux_train_utils.add_flux_train_arguments(parser)
    return parser
|
||||
|
||||
if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    # Validate CLI combinations, then overlay values from a config file if given.
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = CogView4NetworkTrainer()
    trainer.train(args)
|
||||
379
library/strategy_cogview4.py
Normal file
379
library/strategy_cogview4.py
Normal file
@@ -0,0 +1,379 @@
|
||||
import os
|
||||
import glob
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from library import train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
GLM_TOKENIZER_ID = "THUDM/CogView4-6B"
|
||||
|
||||
|
||||
class CogView4TokenizeStrategy(TokenizeStrategy):
    """Tokenization strategy wrapping the GLM tokenizer used by CogView4.

    NOTE(review): the default max_length here is 512, while the trainer
    defaults to 128 — confirm which value is intended.
    """

    def __init__(self, max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
        self.max_length = max_length
        self.tokenizer = self._load_tokenizer(AutoTokenizer, GLM_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        # Add special tokens if needed: the GLM tokenizer may ship without a
        # pad token, so EOS is reused for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """Tokenize a prompt (or list of prompts) to fixed-length tensors.

        Returns:
            [input_ids, attention_mask], each of shape (batch, max_length).
        """
        text = [text] if isinstance(text, str) else text

        # Tokenize with GLM tokenizer
        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        input_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]

        return [input_ids, attention_mask]
|
||||
|
||||
|
||||
class CogView4TextEncodingStrategy(TextEncodingStrategy):
    """Encodes GLM tokens into hidden states for CogView4 conditioning."""

    def __init__(self, apply_attention_mask: bool = True) -> None:
        """
        Args:
            apply_attention_mask: Whether to apply attention mask during encoding.
        """
        self.apply_attention_mask = apply_attention_mask

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: List[torch.Tensor],
        apply_attention_mask: Optional[bool] = None,
    ) -> List[torch.Tensor]:
        """Run the GLM encoder over tokenized prompts.

        Args:
            tokenize_strategy: Unused here; kept for interface compatibility.
            models: Single-element list containing the GLM model.
            tokens: [input_ids, attention_mask] from CogView4TokenizeStrategy.
            apply_attention_mask: Overrides the instance setting when not None.

        Returns:
            [hidden_states, hidden_states, txt_ids placeholder, attention_mask];
            hidden_states is duplicated to keep the FLUX-style slot layout.

        NOTE(review): encoding always runs under torch.no_grad(), which would
        also block gradients if the text encoder were being trained — confirm
        this strategy is only ever used for caching/inference.
        """
        # supports single model inference
        if apply_attention_mask is None:
            apply_attention_mask = self.apply_attention_mask

        # Get GLM model (should be the only model in the list)
        glm_model = models[0]
        input_ids = tokens[0]
        attention_mask = tokens[1] if len(tokens) > 1 else None

        # Move tensors to the correct device
        device = glm_model.device
        input_ids = input_ids.to(device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # Get GLM model outputs
        with torch.no_grad():
            outputs = glm_model(
                input_ids=input_ids,
                attention_mask=attention_mask if apply_attention_mask else None,
                output_hidden_states=True,
                return_dict=True
            )

        # Get the last hidden state
        hidden_states = outputs.hidden_states[-1]  # [batch_size, seq_len, hidden_size]

        # For compatibility with existing code, we'll return a list similar to the original
        # but with GLM's hidden states instead of CLIP/T5 outputs
        return [
            hidden_states,  # Replaces l_pooled
            hidden_states,  # Replaces t5_out (same tensor for now, can be modified if needed)
            torch.zeros(hidden_states.shape[0], hidden_states.shape[1], 3, device=device),  # txt_ids placeholder
            attention_mask  # attention mask
        ]
|
||||
|
||||
|
||||
class CogView4TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Disk/memory caching of GLM text-encoder outputs for CogView4.

    On-disk npz layout per sample: ``hidden_states`` (seq_len, hidden),
    ``attention_mask`` (seq_len,), and the ``apply_attention_mask`` flag used
    to invalidate caches produced under a different setting.

    Fixes over the original: np.load file handles are closed via context
    managers, and the txt_ids placeholder returned by load_outputs_npz is
    (seq_len, 3) — the original built (seq_len, hidden, 3), inconsistent with
    the in-memory branch of cache_batch_outputs.
    """

    COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_cogview4_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_attention_mask: bool = True,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_attention_mask = apply_attention_mask
        # One-shot flag so the fp8-weights warning is logged only once.
        self.warn_fp8_weights = False

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Cache path: the image path with its extension replaced by the npz suffix."""
        return os.path.splitext(image_abs_path)[0] + CogView4TextEncoderOutputsCachingStrategy.COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str):
        """Return True when a valid on-disk cache exists for npz_path."""
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            # Context manager closes the NpzFile's handle (the original
            # left it open).
            with np.load(npz_path) as npz:
                for field in ("hidden_states", "attention_mask", "apply_attention_mask"):
                    if field not in npz:
                        return False
                # Reject caches produced with a different attention-mask setting.
                if bool(npz["apply_attention_mask"]) != self.apply_attention_mask:
                    return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            logger.exception(e)
            return False

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """Load one sample's cached outputs.

        Returns [hidden_states, hidden_states, txt_ids, attention_mask] to
        mirror CogView4TextEncodingStrategy.encode_tokens.
        """
        with np.load(npz_path) as data:
            hidden_states = data["hidden_states"]
            attention_mask = data["attention_mask"]
        return [
            hidden_states,  # l_pooled replacement
            hidden_states,  # t5_out replacement
            # Per-sample txt_ids placeholder: (seq_len, 3), matching the
            # in-memory branch of cache_batch_outputs.
            np.zeros((hidden_states.shape[0], 3), dtype=np.float32),
            attention_mask,  # attention mask
        ]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        """Encode a batch of captions and cache the results (disk or memory)."""
        if not self.warn_fp8_weights:
            model_dtype = next(models[0].parameters()).dtype
            if model_dtype == torch.float8_e4m3fn or model_dtype == torch.float8_e5m2:
                logger.warning(
                    "Model is using fp8 weights for caching. This may affect the quality of the cached outputs."
                    " / モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
                )
                self.warn_fp8_weights = True

        captions = [info.caption for info in infos]

        tokens_and_masks = tokenize_strategy.tokenize(captions)

        with torch.no_grad():
            hidden_states, _, _, attention_mask = text_encoding_strategy.encode_tokens(
                tokenize_strategy, models, tokens_and_masks
            )

        # np.savez cannot store bfloat16; promote to float32 first.
        if hidden_states.dtype == torch.bfloat16:
            hidden_states = hidden_states.float()

        hidden_states = hidden_states.cpu().numpy()
        attention_mask = attention_mask.cpu().numpy() if attention_mask is not None else None

        for i, info in enumerate(infos):
            hidden_states_i = hidden_states[i]
            attention_mask_i = attention_mask[i] if attention_mask is not None else None

            if self.cache_to_disk and hasattr(info, 'text_encoder_outputs_npz'):
                np.savez(
                    info.text_encoder_outputs_npz,
                    hidden_states=hidden_states_i,
                    attention_mask=attention_mask_i,
                    apply_attention_mask=self.apply_attention_mask,
                )
            else:
                # In-memory layout mirrors load_outputs_npz.
                info.text_encoder_outputs = (
                    hidden_states_i,
                    hidden_states_i,
                    np.zeros((hidden_states_i.shape[0], 3), dtype=np.float32),
                    attention_mask_i,
                )
|
||||
|
||||
|
||||
class CogView4LatentsCachingStrategy(LatentsCachingStrategy):
    """Latents caching strategy for CogView4.

    Thin wrapper around the ``_default_*`` helpers of ``LatentsCachingStrategy``
    that fixes the CogView4 npz suffix and enables multi-resolution caching.
    """

    COGVIEW4_LATENTS_NPZ_SUFFIX = "_cogview4.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        """Suffix appended to every cached-latents file name."""
        return CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Return the cache-file path for *absolute_path* at *image_size*.

        Args:
            absolute_path: Absolute path to the source image.
            image_size: Target resolution as (height, width).

        Returns:
            Path of the npz file that holds (or will hold) the cached latents.
        """
        stem = os.path.splitext(absolute_path)[0]
        size_tag = f"_{image_size[0]:04d}x{image_size[1]:04d}"
        return stem + size_tag + CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX

    def is_disk_cached_latents_expected(
        self,
        bucket_reso: Tuple[int, int],
        npz_path: str,
        flip_aug: bool,
        alpha_mask: bool
    ) -> bool:
        """Return True when a valid latents cache already exists on disk.

        Args:
            bucket_reso: Target resolution as (height, width).
            npz_path: Path to the cached latents file.
            flip_aug: Whether flip augmentation was applied.
            alpha_mask: Whether an alpha mask was used.
        """
        # 8 is passed as the default frame count for compatibility with the
        # shared multi-resolution helper.
        return self._default_is_disk_cached_latents_expected(
            8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
        )

    def load_latents_from_disk(
        self,
        npz_path: str,
        bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """Load cached latents from *npz_path*.

        Args:
            npz_path: Path to the cached latents file.
            bucket_reso: Target resolution as (height, width).

        Returns:
            Tuple of (latents, original_size, crop_top_left, alpha_mask,
            alpha_mask_origin); elements are None when unavailable.
        """
        # Same compatibility frame count as is_disk_cached_latents_expected.
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)

    def cache_batch_latents(
        self,
        vae: Any,
        image_infos: List[Any],
        flip_aug: bool,
        alpha_mask: bool,
        random_crop: bool
    ) -> None:
        """Encode a batch of images with *vae* and cache the resulting latents.

        Args:
            vae: VAE model used for encoding.
            image_infos: Image information objects for the batch.
            flip_aug: Whether to apply flip augmentation.
            alpha_mask: Whether to use an alpha mask.
            random_crop: Whether to apply random cropping.
        """

        def encode_to_cpu(img_tensor: torch.Tensor) -> torch.Tensor:
            # Inference-only encode; move the result off the GPU immediately.
            with torch.no_grad():
                return vae.encode(img_tensor).to("cpu")

        # Delegate the actual batching/saving to the shared implementation.
        self._default_cache_batch_latents(
            encode_to_cpu,
            vae.device,
            vae.dtype,
            image_infos,
            flip_aug,
            alpha_mask,
            random_crop,
            multi_resolution=True
        )

        # Free VRAM after the batch unless running in high-VRAM mode.
        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke tests for the CogView4 tokenize / encode / cache strategies.
    tokenizer = CogView4TokenizeStrategy(512)
    text = "hello world"

    # Test single text tokenization
    input_ids, attention_mask = tokenizer.tokenize(text)
    print("Input IDs:", input_ids)
    print("Attention Mask:", attention_mask)

    # Test batch tokenization
    texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
    batch_input_ids, batch_attention_mask = tokenizer.tokenize(texts)
    print("\nBatch Input IDs:", batch_input_ids.shape)
    print("Batch Attention Mask:", batch_attention_mask.shape)

    # Test with a long text (exercises truncation/padding behavior)
    long_text = ",".join(["hello world! this is long text"] * 10)
    long_input_ids, long_attention_mask = tokenizer.tokenize(long_text)
    print("\nLong text input IDs shape:", long_input_ids.shape)
    print("Long text attention mask shape:", long_attention_mask.shape)

    # Test text encoding strategy
    print("\nTesting text encoding strategy...")
    from transformers import AutoModel

    # Load a small GLM model for testing
    model = AutoModel.from_pretrained("THUDM/glm-10b-chinese", trust_remote_code=True)
    model.eval()

    encoding_strategy = CogView4TextEncodingStrategy()
    tokens = tokenizer.tokenize(texts)
    encoded = encoding_strategy.encode_tokens(tokenizer, [model], tokens)

    print(f"Number of outputs: {len(encoded)}")
    print(f"Hidden states shape: {encoded[0].shape}")
    print(f"Attention mask shape: {encoded[3].shape if encoded[3] is not None else 'None'}")

    # Test caching strategy
    print("\nTesting caching strategy...")
    import tempfile
    import os

    class DummyInfo:
        # Minimal stand-in for an image info object: just a caption and a
        # disk-cache target path.
        def __init__(self, caption):
            self.caption = caption
            # FIX: tempfile.mktemp is deprecated and race-prone (the name can
            # be claimed by another process before use). mkstemp atomically
            # creates the file; close the fd so np.savez can reopen the path.
            fd, self.text_encoder_outputs_npz = tempfile.mkstemp(suffix=".npz")
            os.close(fd)

    # Create test data
    infos = [DummyInfo(text) for text in texts]

    # Test caching
    caching_strategy = CogView4TextEncoderOutputsCachingStrategy(
        cache_to_disk=True,
        batch_size=2,
        skip_disk_cache_validity_check=False
    )

    # Cache the outputs
    caching_strategy.cache_batch_outputs(tokenizer, [model], encoding_strategy, infos)

    # Check if files were created
    for info in infos:
        exists = os.path.exists(info.text_encoder_outputs_npz)
        print(f"Cache file {info.text_encoder_outputs_npz} exists: {exists}")

    # Clean up
    for info in infos:
        if os.path.exists(info.text_encoder_outputs_npz):
            os.remove(info.text_encoder_outputs_npz)
|
||||
Reference in New Issue
Block a user