mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
618 lines
26 KiB
Python
618 lines
26 KiB
Python
import argparse
|
|
import copy
|
|
import math
|
|
import random
|
|
import os
|
|
import inspect
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from accelerate import Accelerator
|
|
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
|
from diffusers.pipelines.cogview4.pipeline_cogview4 import CogView4Pipeline
|
|
from PIL import Image
|
|
from transformers import AutoTokenizer, GlmModel
|
|
|
|
from library.device_utils import clean_memory_on_device, init_ipex
|
|
|
|
init_ipex()
|
|
|
|
import train_network
|
|
from library import (
|
|
flux_models,
|
|
flux_train_utils,
|
|
flux_utils,
|
|
sd3_train_utils,
|
|
strategy_base,
|
|
strategy_cogview4,
|
|
train_util,
|
|
)
|
|
from library.utils import setup_logging
|
|
|
|
setup_logging()
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CogView4NetworkTrainer(train_network.NetworkTrainer):
    """Network trainer subclass that wires CogView4 components into the base trainer."""

    def __init__(self):
        """Run the base initialiser, then set the CogView4-specific attributes."""
        super().__init__()
        # Starts disabled; nothing in this constructor turns block swapping on.
        self.is_swapping_blocks: bool = False
        # Populated later with text-encoder outputs for the sample prompts, if cached.
        self.sample_prompts_te_outputs = None
|
|
|
|
def assert_extra_args(
    self,
    args,
    train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
    val_dataset_group: Optional[train_util.DatasetGroup],
):
    """Validate and normalise ``args`` for a CogView4 run.

    After the base-class checks this method:

    * force-disables ``fp8_base_unet``, ``cache_text_encoder_outputs`` and
      ``cache_text_encoder_outputs_to_disk`` when they are set, logging a
      warning for each;
    * adds ``max_token_length`` (128) and ``resolution`` (256) to ``args``
      when the namespace does not define them at all;
    * applies the resolution to the training dataset group and checks the
      bucket resolution steps of the validation group, if one was given.

    Args:
        args: Parsed training arguments; mutated in place.
        train_dataset_group: Training dataset group.
        val_dataset_group: Validation dataset group, or ``None``.
    """
    super().assert_extra_args(args, train_dataset_group, val_dataset_group)

    # Options this trainer switches off, paired with the warning to emit.
    unsupported_flags = (
        ("fp8_base_unet", "FP8 training is not yet fully supported for CogView4. Disabling fp8_base_unet."),
        ("cache_text_encoder_outputs", "Text encoder output caching is not yet implemented for CogView4. Disabling."),
        (
            "cache_text_encoder_outputs_to_disk",
            "Text encoder output disk caching is not yet implemented for CogView4. Disabling.",
        ),
    )
    for flag_name, warning_text in unsupported_flags:
        if getattr(args, flag_name, False):
            logger.warning(warning_text)
            setattr(args, flag_name, False)

    # Defaults are only applied when the attribute is missing entirely,
    # not when it is present but None.
    fallback_values = (
        ("max_token_length", 128),  # Default token length for GLM
        ("resolution", 256),  # Default resolution for CogView4
    )
    for attr_name, fallback in fallback_values:
        if not hasattr(args, attr_name):
            setattr(args, attr_name, fallback)

    # Propagate the resolution to the dataset.
    train_dataset_group.set_resolution(args.resolution)
    if val_dataset_group is not None:
        val_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
|
|
|
|
def load_target_model(self, args, weight_dtype, accelerator):
    """Load the CogView4 tokenizer, text encoder, VAE and transformer.

    Every component is read from ``args.pretrained_model_name_or_path`` using
    the ``tokenizer``, ``text_encoder``, ``vae`` and ``transformer``
    subfolders. The models are moved to ``accelerator.device`` and, together
    with the tokenizer, a freshly created noise scheduler and
    ``weight_dtype``, stored on ``self``.

    Args:
        args: Training arguments; reads ``pretrained_model_name_or_path`` and
            ``gradient_checkpointing``.
        weight_dtype: ``torch_dtype`` used when loading the three models.
        accelerator: Accelerator whose ``device`` receives the models.

    Returns:
        The tuple ``("cogview4-v1", [text_encoder], vae, transformer)``.
    """
    model_path = args.pretrained_model_name_or_path
    logger.info(f"Loading CogView4 model from {model_path}")

    # Tokenizer and text encoder (GLM)
    self.tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        subfolder="tokenizer",
        use_fast=False,
        trust_remote_code=True,
    )
    self.text_encoder = GlmModel.from_pretrained(
        model_path,
        subfolder="text_encoder",
        torch_dtype=weight_dtype,
        trust_remote_code=True,
    )
    self.text_encoder.eval()

    # VAE
    self.vae = AutoencoderKL.from_pretrained(
        model_path,
        subfolder="vae",
        torch_dtype=weight_dtype,
        trust_remote_code=True,
    )
    self.vae.eval()

    # Transformer (the denoiser that the base trainer treats as the "unet")
    self.transformer = CogView4Transformer2DModel.from_pretrained(
        model_path,
        subfolder="transformer",
        torch_dtype=weight_dtype,
        trust_remote_code=True,
    )

    # FlowMatchEulerDiscreteScheduler has no beta_start / beta_end /
    # beta_schedule / prediction_type parameters (those belong to DDPM-style
    # schedulers); passing them makes the constructor raise TypeError, so only
    # the supported argument is given.
    # NOTE(review): consider loading the checkpoint's own scheduler config
    # (shift settings etc.) instead of library defaults - confirm the
    # checkpoint layout first.
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

    # Move models to the training device
    device = accelerator.device
    self.text_encoder = self.text_encoder.to(device)
    self.vae = self.vae.to(device)
    self.transformer = self.transformer.to(device)

    if args.gradient_checkpointing:
        self.transformer.enable_gradient_checkpointing()
        # Gradient checkpointing for the text encoder is not enabled here;
        # presumably the base trainer does that through
        # prepare_text_encoder_grad_ckpt_workaround - verify against the base class.

    # Kept for later use by other methods
    self.weight_dtype = weight_dtype

    return "cogview4-v1", [self.text_encoder], self.vae, self.transformer
|
|
|
|
def get_tokenize_strategy(self, args):
    """Create the CogView4 tokenize strategy.

    The token limit comes from ``args.max_token_length`` (128 when the
    attribute is missing) and is logged before the strategy is built with
    ``args.tokenizer_cache_dir``.
    """
    token_limit = getattr(args, 'max_token_length', 128)
    logger.info(f"Using max_token_length: {token_limit} for GLM tokenizer")
    cache_dir = args.tokenizer_cache_dir
    return strategy_cogview4.CogView4TokenizeStrategy(token_limit, cache_dir)
|
|
|
|
def get_tokenizers(self, tokenize_strategy):
    """Return the strategy's single tokenizer wrapped in a one-element list."""
    glm_tokenizer = tokenize_strategy.tokenizer
    return [glm_tokenizer]
|
|
|
|
def get_latents_caching_strategy(self, args):
    """Create the CogView4 latents caching strategy.

    Uses ``args.cache_latents_to_disk`` and ``args.vae_batch_size``; the disk
    cache validity check is never skipped.
    """
    to_disk = args.cache_latents_to_disk
    batch_size = args.vae_batch_size
    return strategy_cogview4.CogView4LatentsCachingStrategy(
        to_disk, batch_size, skip_disk_cache_validity_check=False
    )
|
|
|
|
def get_text_encoding_strategy(self, args):
    """Create the CogView4 text encoding strategy.

    ``apply_attention_mask`` is read from ``args`` and defaults to ``True``
    when the attribute is missing.
    """
    use_attention_mask = getattr(args, 'apply_attention_mask', True)
    return strategy_cogview4.CogView4TextEncodingStrategy(apply_attention_mask=use_attention_mask)
|
|
|
|
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
    """Return the models to use for text encoding.

    The given ``text_encoders`` list is handed back unchanged; ``args`` and
    ``accelerator`` are not consulted.
    """
    encoders_for_encoding = text_encoders
    return encoders_for_encoding
|
|
|
|
def get_text_encoders_train_flags(self, args, text_encoders):
    """Return a one-element list with the training flag for the text encoder.

    The flag is ``args.train_text_encoder``, or ``False`` when the attribute
    is missing; ``text_encoders`` is not inspected.
    """
    train_flag = getattr(args, 'train_text_encoder', False)
    return [train_flag]
|
|
|
|
def get_text_encoder_outputs_caching_strategy(self, args):
    """Return the text-encoder-output caching strategy, or ``None``.

    ``None`` is returned unless ``args.cache_text_encoder_outputs`` is truthy.
    Otherwise the strategy is configured from ``args``, with a default used
    for every attribute that is missing.
    """
    if not getattr(args, 'cache_text_encoder_outputs', False):
        return None

    strategy_options = {
        "cache_to_disk": getattr(args, 'cache_text_encoder_outputs_to_disk', False),
        "batch_size": getattr(args, 'text_encoder_batch_size', 1),
        "skip_disk_cache_validity_check": getattr(args, 'skip_cache_check', False),
        "is_partial": getattr(args, 'train_text_encoder', False),
    }
    return strategy_cogview4.CogView4TextEncoderOutputsCachingStrategy(**strategy_options)
|
|
|
|
def cache_text_encoder_outputs_if_needed(
    self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
    """Pre-compute text encoder outputs when caching is enabled.

    When ``args.cache_text_encoder_outputs`` is falsy, the only effect is to
    move the first text encoder (if any) to ``accelerator.device`` in
    ``weight_dtype``.

    When it is truthy, the steps run in this order:

    1. unless ``args.lowram`` is set, move the VAE and UNet to the CPU and
       clean device memory;
    2. move the text encoder to ``accelerator.device``;
    3. have the dataset cache its text encoder outputs;
    4. if ``args.sample_prompts`` is set, encode every distinct non-empty
       prompt / negative prompt and store the results in
       ``self.sample_prompts_te_outputs``;
    5. wait for all processes;
    6. move the text encoder to the CPU unless it is being trained, and move
       the VAE and UNet back to their original devices (again skipped in
       ``lowram`` mode).

    Args:
        args: Training arguments
        accelerator: Accelerator instance
        unet: UNet model
        vae: VAE model
        text_encoders: List containing the GLM text encoder
        dataset: Dataset to cache text encoder outputs for
        weight_dtype: Data type for weights
    """
    if not getattr(args, 'cache_text_encoder_outputs', False):
        # No caching: the encoder only needs to be on the training device.
        if text_encoders:
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
        return

    offload_models = not getattr(args, 'lowram', False)
    if offload_models:
        # Make room on the device while the text encoder does its work.
        logger.info("Moving VAE and UNet to CPU to save memory")
        org_vae_device, org_unet_device = vae.device, unet.device
        vae.to("cpu")
        unet.to("cpu")
        clean_memory_on_device(accelerator.device)

    logger.info("Moving text encoder to GPU")
    text_encoder = text_encoders[0]  # CogView4 uses a single text encoder (GLM)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    with accelerator.autocast():
        dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

    if getattr(args, 'sample_prompts', None) is not None:
        logger.info(f"Caching text encoder outputs for sample prompts: {args.sample_prompts}")

        use_attention_mask = getattr(args, 'apply_attention_mask', True)
        tokenize_strategy = strategy_cogview4.CogView4TokenizeStrategy(
            max_length=getattr(args, 'max_token_length', 128),
            tokenizer_cache_dir=getattr(args, 'tokenizer_cache_dir', None),
        )
        text_encoding_strategy = strategy_cogview4.CogView4TextEncodingStrategy(
            apply_attention_mask=use_attention_mask
        )

        prompts = train_util.load_prompts(args.sample_prompts)
        cached_outputs = {}  # key: prompt, value: text encoder outputs

        with accelerator.autocast(), torch.no_grad():
            for prompt_dict in prompts:
                candidates = (prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""))
                for text in candidates:
                    # Empty prompts and prompts already encoded are skipped.
                    if not text or text in cached_outputs:
                        continue
                    logger.info(f"Caching text encoder outputs for prompt: {text}")
                    tokens = tokenize_strategy.tokenize(text)
                    cached_outputs[text] = text_encoding_strategy.encode_tokens(
                        tokenize_strategy=tokenize_strategy,
                        models=text_encoders,
                        tokens=tokens,
                        apply_attention_mask=use_attention_mask,
                    )
        self.sample_prompts_te_outputs = cached_outputs

    accelerator.wait_for_everyone()

    # A text encoder that is not trained is no longer needed on the device.
    if not getattr(args, 'train_text_encoder', False):
        logger.info("Moving text encoder back to CPU")
        text_encoder.to("cpu")
        clean_memory_on_device(accelerator.device)

    if offload_models:
        logger.info("Moving VAE and UNet back to original devices")
        vae.to(org_vae_device)
        unet.to(org_unet_device)
|
|
|
|
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
|
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
|
|
|
# # get size embeddings
|
|
# orig_size = batch["original_sizes_hw"]
|
|
# crop_size = batch["crop_top_lefts"]
|
|
# target_size = batch["target_sizes_hw"]
|
|
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
|
|
|
# # concat embeddings
|
|
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
|
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
|
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
|
|
|
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
|
# return noise_pred
|
|
|
|
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, transformer):
    """Generate a fixed set of sample images to monitor training progress.

    Three hard-coded prompts are encoded with ``self.tokenizer`` and
    ``self.text_encoder`` (the ``tokenizer`` / ``text_encoder`` parameters are
    not used), denoised for 50 scheduler steps with ``transformer`` and
    decoded with ``vae``. On the main process the images are written to
    ``<args.output_dir>/samples`` and, when the ``wandb`` module is already
    loaded, also logged through ``accelerator.log``.

    The transformer and the VAE are put in eval mode for sampling and each is
    restored to its own previous mode afterwards, even if sampling raises.

    Args:
        accelerator: Accelerator used for the main-process check and logging.
        args: Training arguments; reads ``guidance_scale``, ``resolution`` and
            ``output_dir``.
        epoch: Epoch number, used in the output file names.
        global_step: Global step, used in file names and as the logging step.
        device: Device on which sampling tensors are created.
        vae: VAE used to decode latents.
        tokenizer: Unused.
        text_encoder: Unused.
        transformer: Denoising model.

    Returns:
        List of generated PIL images, one per prompt.
    """
    import sys  # stdlib; only used to look up an already-imported wandb module

    logger.info(f"Generating sample images at step {global_step}")

    # Track each model's own mode: the VAE must not be switched to train mode
    # just because the transformer was training.
    transformer_was_training = transformer.training
    vae_was_training = vae.training
    transformer.eval()
    vae.eval()

    sample_prompts = [
        "A high quality photo of a cat",
        "A beautiful landscape with mountains and a lake",
        "A futuristic city at night",
    ]
    use_cfg = args.guidance_scale > 1.0
    all_images = []

    try:
        # One no_grad context covers encoding, denoising and decoding.
        with torch.no_grad():
            for prompt in sample_prompts:
                text_input = self.tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                input_ids = text_input.input_ids.to(device)
                attention_mask = text_input.attention_mask.to(device)

                text_embeddings = self.text_encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    return_dict=True,
                ).last_hidden_state

                # NOTE(review): both halves of the guidance batch carry the
                # same conditional embeddings and the attention mask is not
                # duplicated - confirm this is the intended guidance setup.
                encoder_states = torch.cat([text_embeddings] * 2) if use_cfg else text_embeddings

                latents = torch.randn(
                    (1, 4, args.resolution // 8, args.resolution // 8),
                    device=device,
                    dtype=torch.float32,
                )

                self.noise_scheduler.set_timesteps(50, device=device)

                for t in self.noise_scheduler.timesteps:
                    # Duplicate the latents when doing classifier-free guidance
                    latent_model_input = torch.cat([latents] * 2) if use_cfg else latents
                    latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)

                    noise_pred = transformer(
                        latent_model_input,
                        t.unsqueeze(0).repeat(latent_model_input.shape[0]),
                        encoder_hidden_states=encoder_states,
                        attention_mask=attention_mask,
                        return_dict=True,
                    ).sample

                    if use_cfg:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # x_t -> x_t-1
                    latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample

                # NOTE(review): 0.18215 is a hard-coded latent scale - confirm
                # it matches the VAE loaded for CogView4.
                latents = 1 / 0.18215 * latents
                image = vae.decode(latents).sample

                # [-1, 1] tensor -> uint8 HWC arrays -> PIL
                image = (image / 2 + 0.5).clamp(0, 1)
                image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
                images = (image * 255).round().astype("uint8")
                all_images.extend(Image.fromarray(frame) for frame in images)
    finally:
        if transformer_was_training:
            transformer.train()
        if vae_was_training:
            vae.train()

    if accelerator.is_main_process and hasattr(accelerator, "log"):
        sample_dir = os.path.join(args.output_dir, "samples")
        os.makedirs(sample_dir, exist_ok=True)

        log_images = []
        for i, img in enumerate(all_images):
            log_images.append(np.array(img))
            img.save(os.path.join(sample_dir, f"sample_epoch{epoch}_step{global_step}_{i}.png"))

        # wandb is never imported by this file, so a bare reference would raise
        # NameError. Use it only when something else has already loaded it.
        wandb = sys.modules.get("wandb")
        if wandb is not None:
            accelerator.log(
                {
                    "samples": [
                        wandb.Image(img, caption=f"{sample_prompts[i]}")
                        for i, img in enumerate(log_images)
                    ]
                },
                step=global_step,
            )
        else:
            logger.info("wandb is not loaded; sample images were saved to disk only.")

    return all_images
|
|
|
|
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
    """Return the scheduler stored on ``self.noise_scheduler``.

    ``args`` and ``device`` are accepted but not used; no new scheduler is
    built here.
    """
    scheduler = self.noise_scheduler
    return scheduler
|
|
|
|
def encode_images_to_latents(self, args, vae, images):
    """Encode ``images`` with ``vae`` and return latents.

    If the encoder output exposes a ``latent_dist`` attribute (as an
    ``AutoencoderKL``-style output object does), a sample drawn from that
    distribution is returned. Any other output is returned unchanged, so a
    VAE whose ``encode`` already yields latents keeps working.

    Args:
        args: Training arguments (unused).
        vae: Object with an ``encode(images)`` method.
        images: Batch passed straight to ``vae.encode``.

    Returns:
        ``vae.encode(images).latent_dist.sample()`` when available, otherwise
        the raw ``vae.encode(images)`` result.
    """
    encoded = vae.encode(images)
    posterior = getattr(encoded, "latent_dist", None)
    if posterior is None:
        return encoded
    return posterior.sample()
|
|
|
|
def shift_scale_latents(self, args, latents):
    """Return ``latents`` unchanged; no shift or scale is applied here."""
    unchanged = latents
    return unchanged
|
|
|
|
def get_noise_pred_and_target(
    self,
    args,
    accelerator,
    noise_scheduler,
    latents,
    batch,
    text_encoder_conds,
    transformer: "CogView4Transformer2DModel",
    network,
    weight_dtype,
    train_unet=True,
    is_train=True,
):
    """Compute the model prediction and the regression target for one batch.

    Noise and a random timestep per sample are drawn, the latents are noised
    via ``noise_scheduler.add_noise``, the prompts in ``batch["input_ids"]``
    are encoded with ``self.text_encoder`` and the transformer is run on the
    noisy latents.

    When ``args.guidance_scale > 1.0``:

    * text conditioning is zeroed for roughly 5% of the samples;
    * during training the plain conditional prediction is returned, so
      guidance is never folded into the training loss;
    * when ``is_train`` is false, one extra unconditional pass is run and
      classifier-free guidance is applied exactly once.

    ``accelerator``, ``text_encoder_conds`` and ``network`` are accepted but
    not used.

    Returns:
        ``(model_pred, target, timesteps, weighting)`` where ``weighting`` is
        a tensor of ones with one entry per sample in ``weight_dtype``.

    Raises:
        ValueError: If the scheduler's prediction type is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

    # One random timestep per sample
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
    ).long()

    # NOTE(review): add_noise and the positional transformer call below should
    # be confirmed against the scheduler / model classes actually in use.
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    input_ids = batch["input_ids"]
    attention_mask = batch.get("attention_mask", None)

    # NOTE(review): self.train_text_encoder is not set in this class's
    # __init__ - presumably provided elsewhere; verify.
    with torch.set_grad_enabled(self.train_text_encoder):
        text_embeddings = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        ).last_hidden_state

    use_guidance = args.guidance_scale > 1.0

    with torch.set_grad_enabled(train_unet):
        if use_guidance:
            # Drop text conditioning ~5% of the time. The keep-mask is built in
            # the embeddings' own dtype so half-precision embeddings are not
            # silently promoted to float32.
            keep = (torch.rand(bsz, device=latents.device) >= 0.05).to(text_embeddings.dtype)
            text_embeddings = text_embeddings * keep.view(bsz, 1, 1)

        model_pred = transformer(
            noisy_latents,
            timesteps,
            encoder_hidden_states=text_embeddings,
            attention_mask=attention_mask,
            return_dict=True,
        ).sample

    # Schedulers without a prediction_type entry are treated as "epsilon",
    # the value this trainer was written for.
    prediction_type = getattr(noise_scheduler.config, "prediction_type", "epsilon")
    if prediction_type == "epsilon":
        target = noise
    elif prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {prediction_type}")

    # The loss must be computed on the conditional prediction while training,
    # so the unconditional pass and the guidance mix only run outside
    # training - and only once.
    if use_guidance and not is_train:
        model_pred_uncond = transformer(
            noisy_latents,
            timesteps,
            encoder_hidden_states=torch.zeros_like(text_embeddings),
            attention_mask=attention_mask,
            return_dict=True,
        ).sample
        model_pred = model_pred_uncond + args.guidance_scale * (model_pred - model_pred_uncond)

    # Uniform weighting - can be made timestep-dependent if needed
    weighting = torch.ones_like(timesteps, dtype=weight_dtype, device=latents.device)

    return model_pred, target, timesteps, weighting
|
|
|
|
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
    """Apply optional timestep weighting and value clamping to ``loss``.

    Every option is read from ``args`` and is skipped when the attribute is
    missing *or* ``None`` (argparse commonly leaves unset options as
    ``None``, which cannot be compared with a number).

    * ``timestep_bias_portion > 0`` enables weighting. Samples whose timestep
      lies below ``timestep_bias_begin`` (if > 0) or above
      ``timestep_bias_end`` (if < 1000) get weight 0, and all weights are
      scaled by ``timestep_bias_multiplier`` (if != 1.0).
    * ``clip_grad_norm > 0`` clamps the loss values to
      ``[-clip_grad_norm, clip_grad_norm]``. This clamps loss *values*; it
      does not clip gradients.

    Args:
        loss: Loss tensor; must broadcast against one weight per timestep.
        args: Training arguments.
        timesteps: Timesteps the loss entries were computed for.
        noise_scheduler: Unused.

    Returns:
        The processed loss tensor.
    """
    bias_portion = getattr(args, "timestep_bias_portion", None)
    if bias_portion is not None and bias_portion > 0.0:
        weights = torch.ones_like(timesteps, dtype=torch.float32)

        bias_begin = getattr(args, "timestep_bias_begin", None)
        if bias_begin is not None and bias_begin > 0:
            weights[timesteps < bias_begin] = 0.0

        bias_end = getattr(args, "timestep_bias_end", None)
        if bias_end is not None and bias_end < 1000:
            weights[timesteps > bias_end] = 0.0

        bias_multiplier = getattr(args, "timestep_bias_multiplier", None)
        if bias_multiplier is not None and bias_multiplier != 1.0:
            weights = weights * bias_multiplier

        loss = loss * weights.to(loss.device)

    clip_value = getattr(args, "clip_grad_norm", None)
    if clip_value is not None and clip_value > 0.0:
        loss = torch.clamp(loss, -clip_value, clip_value)

    return loss
|
|
|
|
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments accepted by the scheduler's ``step``.

    The signature of ``self.noise_scheduler.step`` is inspected; ``eta`` and
    ``generator`` are included in the result only if ``step`` declares a
    parameter with that name.

    Returns:
        A dict containing zero, one or both of ``"eta"`` and ``"generator"``.
    """
    # TODO: Logic should ideally just be moved to base class
    step_parameters = inspect.signature(self.noise_scheduler.step).parameters
    optional_values = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_values if name in step_parameters}
|
|
|
|
def update_metadata(self, metadata, args):
    """Record the CogView4-related training options in ``metadata``.

    Each entry is stored under ``"ss_<name>"``. The attention-mask flag is
    read from ``args.apply_attn_mask`` when that attribute exists, otherwise
    from ``args.apply_attention_mask`` (the name the rest of this trainer
    reads), defaulting to ``True`` like those other reads do. All remaining
    options are read directly and must exist on ``args``.

    Args:
        metadata: Dict updated in place.
        args: Training arguments.
    """
    metadata["ss_apply_attn_mask"] = getattr(
        args, "apply_attn_mask", getattr(args, "apply_attention_mask", True)
    )

    recorded_options = (
        "weighting_scheme",
        "logit_mean",
        "logit_std",
        "mode_scale",
        "guidance_scale",
        "timestep_sampling",
        "sigmoid_scale",
        "model_prediction_type",
        "discrete_flow_shift",
    )
    for option in recorded_options:
        metadata[f"ss_{option}"] = getattr(args, option)
|
|
|
|
def is_text_encoder_not_needed_for_training(self, args):
    """Tell whether training can proceed without the text encoder.

    That is the case when text encoder outputs are cached
    (``cache_text_encoder_outputs``) and the encoder itself is not trained
    (``train_text_encoder``); both flags default to ``False`` when missing.
    """
    outputs_cached = getattr(args, 'cache_text_encoder_outputs', False)
    encoder_trained = getattr(args, 'train_text_encoder', False)
    return outputs_cached and not encoder_trained
|
|
|
|
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
    """Delegate the gradient-checkpointing workaround to the base class.

    CogView4 has a single text encoder (GLM), so nothing index-specific is
    done here; ``index`` and ``text_encoder`` are forwarded unchanged and the
    base-class result is returned as is.
    """
    base_result = super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
    return base_result
|
|
|
|
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
    """Cast the text encoder for FP8 training.

    The whole encoder is converted to ``te_weight_dtype``; afterwards
    ``.to(dtype=weight_dtype)`` is called on its ``word_embeddings`` and
    ``position_embeddings`` attributes, for each one that exists.

    Args:
        index: Text encoder index (0 is expected; other values only trigger
            a warning)
        text_encoder: The text encoder model (GLM)
        te_weight_dtype: Target weight dtype for the encoder
        weight_dtype: Base weight dtype for embeddings

    Returns:
        The same ``text_encoder`` object.
    """
    if index != 0:
        # Not fatal: the encoder is still treated as the single GLM encoder.
        logger.warning(f"Unexpected text encoder index {index} for CogView4, expecting 0.")

    logger.info(f"Preparing GLM text encoder (index {index}) for {te_weight_dtype}, embeddings to {weight_dtype}")
    text_encoder.to(te_weight_dtype)

    # NOTE(review): these attribute names are guesses about the GLM module
    # layout; when neither exists nothing extra is cast - confirm against the
    # actual model class.
    for embedding_attr in ("word_embeddings", "position_embeddings"):
        if hasattr(text_encoder, embedding_attr):
            getattr(text_encoder, embedding_attr).to(dtype=weight_dtype)

    return text_encoder
|
|
|
|
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    """Hook invoked after each validation step; intentionally does nothing.

    CogView4 needs no per-step cleanup here (e.g. no block swapping).
    """
    return None
|
|
|
|
def prepare_unet_with_accelerator(
    self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
    """Prepare the denoiser (the CogView4 transformer) with the accelerator.

    No CogView4-specific handling is added: the call is forwarded to the base
    class and its result returned unchanged.
    """
    prepared_unet = super().prepare_unet_with_accelerator(args, accelerator, unet)
    return prepared_unet
|
|
|
|
|
|
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for CogView4 network training.

    Starts from the generic network-training parser and registers the DiT and
    FLUX training argument groups on it, in that order.
    """
    parser = train_network.setup_parser()
    argument_registrars = (
        train_util.add_dit_training_arguments,
        flux_train_utils.add_flux_train_arguments,
    )
    for register_arguments in argument_registrars:
        register_arguments(parser)
    return parser
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse and verify the CLI arguments, merge in any
    # config-file values, then hand over to the trainer.
    parser = setup_parser()
    args = parser.parse_args()

    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    CogView4NetworkTrainer().train(args)
|