mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
init
This commit is contained in:
617
cogview4_train_network.py
Normal file
617
cogview4_train_network.py
Normal file
@@ -0,0 +1,617 @@
|
||||
import argparse
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
import os
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.pipelines.cogview4.pipeline_cogview4 import CogView4Pipeline
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from library.device_utils import clean_memory_on_device, init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
import train_network
|
||||
from library import (
|
||||
flux_models,
|
||||
flux_train_utils,
|
||||
flux_utils,
|
||||
sd3_train_utils,
|
||||
strategy_base,
|
||||
strategy_cogview4,
|
||||
train_util,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CogView4NetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
args,
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
|
||||
# CogView4 specific argument validation
|
||||
if hasattr(args, 'fp8_base_unet') and args.fp8_base_unet:
|
||||
logger.warning("FP8 training is not yet fully supported for CogView4. Disabling fp8_base_unet.")
|
||||
args.fp8_base_unet = False
|
||||
|
||||
if hasattr(args, 'cache_text_encoder_outputs') and args.cache_text_encoder_outputs:
|
||||
logger.warning("Text encoder output caching is not yet implemented for CogView4. Disabling.")
|
||||
args.cache_text_encoder_outputs = False
|
||||
|
||||
if hasattr(args, 'cache_text_encoder_outputs_to_disk') and args.cache_text_encoder_outputs_to_disk:
|
||||
logger.warning("Text encoder output disk caching is not yet implemented for CogView4. Disabling.")
|
||||
args.cache_text_encoder_outputs_to_disk = False
|
||||
|
||||
# Set default values for CogView4
|
||||
if not hasattr(args, 'max_token_length'):
|
||||
args.max_token_length = 128 # Default token length for GLM
|
||||
|
||||
if not hasattr(args, 'resolution'):
|
||||
args.resolution = 256 # Default resolution for CogView4
|
||||
|
||||
# Update dataset resolution if needed
|
||||
train_dataset_group.set_resolution(args.resolution)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
"""
|
||||
Load the CogView4 model components including tokenizer, text encoder, VAE, and transformer.
|
||||
"""
|
||||
logger.info(f"Loading CogView4 model from {args.pretrained_model_name_or_path}")
|
||||
|
||||
# Load tokenizer and text encoder (GLM)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
use_fast=False,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
self.text_encoder = GlmModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
torch_dtype=weight_dtype,
|
||||
trust_remote_code=True
|
||||
)
|
||||
self.text_encoder.eval()
|
||||
|
||||
# Load VAE
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
torch_dtype=weight_dtype,
|
||||
trust_remote_code=True
|
||||
)
|
||||
self.vae.eval()
|
||||
|
||||
# Load transformer
|
||||
self.transformer = CogView4Transformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="transformer",
|
||||
torch_dtype=weight_dtype,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Create noise scheduler
|
||||
self.noise_scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.0001,
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
prediction_type="epsilon",
|
||||
)
|
||||
|
||||
# Move models to device
|
||||
device = accelerator.device
|
||||
self.text_encoder = self.text_encoder.to(device)
|
||||
self.vae = self.vae.to(device)
|
||||
self.transformer = self.transformer.to(device)
|
||||
|
||||
# Set gradient checkpointing if enabled
|
||||
if args.gradient_checkpointing:
|
||||
self.transformer.enable_gradient_checkpointing()
|
||||
# Text encoder gradient checkpointing is handled by prepare_text_encoder_grad_ckpt_workaround
|
||||
# called by the base trainer if args.train_text_encoder is true for that TE.
|
||||
|
||||
# Store components for later use
|
||||
self.weight_dtype = weight_dtype
|
||||
|
||||
# Return components in the expected format
|
||||
return "cogview4-v1", [self.text_encoder], self.vae, self.transformer
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
# For CogView4, we use a fixed token length for GLM
|
||||
max_token_length = getattr(args, 'max_token_length', 128)
|
||||
logger.info(f"Using max_token_length: {max_token_length} for GLM tokenizer")
|
||||
return strategy_cogview4.CogView4TokenizeStrategy(max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
def get_tokenizers(self, tokenize_strategy):
|
||||
# For CogView4, we only have one tokenizer (GLM)
|
||||
return [tokenize_strategy.tokenizer]
|
||||
|
||||
def get_latents_caching_strategy(self, args):
|
||||
return strategy_cogview4.CogView4LatentsCachingStrategy(
|
||||
args.cache_latents_to_disk,
|
||||
args.vae_batch_size,
|
||||
skip_disk_cache_validity_check=False
|
||||
)
|
||||
|
||||
def get_text_encoding_strategy(self, args):
|
||||
# For CogView4, we use GLM instead of T5, but maintain similar interface
|
||||
return strategy_cogview4.CogView4TextEncodingStrategy(
|
||||
apply_attention_mask=getattr(args, 'apply_attention_mask', True)
|
||||
)
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||
# For CogView4, we always return the text encoder (GLM) as it's needed for encoding
|
||||
return text_encoders
|
||||
|
||||
def get_text_encoders_train_flags(self, args, text_encoders):
|
||||
# For CogView4, we only have one text encoder (GLM)
|
||||
return [getattr(args, 'train_text_encoder', False)]
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if getattr(args, 'cache_text_encoder_outputs', False):
|
||||
return strategy_cogview4.CogView4TextEncoderOutputsCachingStrategy(
|
||||
cache_to_disk=getattr(args, 'cache_text_encoder_outputs_to_disk', False),
|
||||
batch_size=getattr(args, 'text_encoder_batch_size', 1),
|
||||
skip_disk_cache_validity_check=getattr(args, 'skip_cache_check', False),
|
||||
is_partial=getattr(args, 'train_text_encoder', False)
|
||||
)
|
||||
return None
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||
):
|
||||
"""Cache text encoder outputs to speed up training.
|
||||
|
||||
Args:
|
||||
args: Training arguments
|
||||
accelerator: Accelerator instance
|
||||
unet: UNet model
|
||||
vae: VAE model
|
||||
text_encoders: List containing the GLM text encoder
|
||||
dataset: Dataset to cache text encoder outputs for
|
||||
weight_dtype: Data type for weights
|
||||
"""
|
||||
if getattr(args, 'cache_text_encoder_outputs', False):
|
||||
if not getattr(args, 'lowram', False):
|
||||
# Free up GPU memory by moving models to CPU
|
||||
logger.info("Moving VAE and UNet to CPU to save memory")
|
||||
org_vae_device = vae.device
|
||||
org_unet_device = unet.device
|
||||
vae.to("cpu")
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# Move text encoder to GPU with proper dtype
|
||||
logger.info("Moving text encoder to GPU")
|
||||
text_encoder = text_encoders[0] # CogView4 uses a single text encoder (GLM)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# Cache text encoder outputs
|
||||
with accelerator.autocast():
|
||||
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
||||
|
||||
# Cache sample prompts if provided
|
||||
if getattr(args, 'sample_prompts', None) is not None:
|
||||
logger.info(f"Caching text encoder outputs for sample prompts: {args.sample_prompts}")
|
||||
|
||||
# Initialize CogView4 strategies
|
||||
tokenize_strategy = strategy_cogview4.CogView4TokenizeStrategy(
|
||||
max_length=getattr(args, 'max_token_length', 128),
|
||||
tokenizer_cache_dir=getattr(args, 'tokenizer_cache_dir', None)
|
||||
)
|
||||
text_encoding_strategy = strategy_cogview4.CogView4TextEncodingStrategy(
|
||||
apply_attention_mask=getattr(args, 'apply_attention_mask', True)
|
||||
)
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
||||
if p and p not in sample_prompts_te_outputs: # Skip empty prompts and duplicates
|
||||
logger.info(f"Caching text encoder outputs for prompt: {p}")
|
||||
tokens = tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy=tokenize_strategy,
|
||||
models=text_encoders,
|
||||
tokens=tokens,
|
||||
apply_attention_mask=getattr(args, 'apply_attention_mask', True)
|
||||
)
|
||||
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Move text encoder back to CPU if not training it
|
||||
if not getattr(args, 'train_text_encoder', False):
|
||||
logger.info("Moving text encoder back to CPU")
|
||||
text_encoder.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# Move VAE and UNet back to their original devices if not in lowram mode
|
||||
if not getattr(args, 'lowram', False):
|
||||
logger.info("Moving VAE and UNet back to original devices")
|
||||
vae.to(org_vae_device)
|
||||
unet.to(org_unet_device)
|
||||
else:
|
||||
# Keep text encoder in GPU if we're not caching outputs
|
||||
if text_encoders:
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# # get size embeddings
|
||||
# orig_size = batch["original_sizes_hw"]
|
||||
# crop_size = batch["crop_top_lefts"]
|
||||
# target_size = batch["target_sizes_hw"]
|
||||
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# # concat embeddings
|
||||
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
||||
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
# return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, transformer):
|
||||
"""
|
||||
Generate sample images during training to monitor progress.
|
||||
"""
|
||||
logger.info(f"Generating sample images at step {global_step}")
|
||||
|
||||
# Set models to eval mode
|
||||
was_training = transformer.training
|
||||
transformer.eval()
|
||||
vae.eval()
|
||||
|
||||
# Sample prompts to use for generation
|
||||
sample_prompts = [
|
||||
"A high quality photo of a cat",
|
||||
"A beautiful landscape with mountains and a lake",
|
||||
"A futuristic city at night"
|
||||
]
|
||||
|
||||
# Generate images for each prompt
|
||||
all_images = []
|
||||
|
||||
with torch.no_grad():
|
||||
for prompt in sample_prompts:
|
||||
# Tokenize the prompt
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = text_input.input_ids.to(device)
|
||||
attention_mask = text_input.attention_mask.to(device)
|
||||
|
||||
# Get text embeddings
|
||||
with torch.no_grad():
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True
|
||||
).last_hidden_state
|
||||
|
||||
# Sample random noise
|
||||
latents = torch.randn(
|
||||
(1, 4, args.resolution // 8, args.resolution // 8),
|
||||
device=device,
|
||||
dtype=torch.float32
|
||||
)
|
||||
|
||||
# Set the scheduler for inference
|
||||
self.noise_scheduler.set_timesteps(50, device=device)
|
||||
|
||||
# Generate image using the denoising process
|
||||
for t in self.noise_scheduler.timesteps:
|
||||
# Expand the latents if we are doing classifier-free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if args.guidance_scale > 1.0 else latents
|
||||
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# Predict the noise residual
|
||||
with torch.no_grad():
|
||||
noise_pred = transformer(
|
||||
latent_model_input,
|
||||
t.unsqueeze(0).repeat(latent_model_input.shape[0]),
|
||||
encoder_hidden_states=torch.cat([text_embeddings] * 2) if args.guidance_scale > 1.0 else text_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
).sample
|
||||
|
||||
# Perform guidance
|
||||
if args.guidance_scale > 1.0:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# Compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
|
||||
|
||||
# Scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
with torch.no_grad():
|
||||
image = vae.decode(latents).sample
|
||||
|
||||
# Convert to PIL image
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
images = (image * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
all_images.extend(pil_images)
|
||||
|
||||
# Log images to tensorboard if available
|
||||
if accelerator.is_main_process and hasattr(accelerator, "log"):
|
||||
log_images = []
|
||||
for i, img in enumerate(all_images):
|
||||
# Convert PIL image to numpy for logging
|
||||
log_images.append(np.array(img))
|
||||
|
||||
# Save individual images
|
||||
os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
|
||||
img.save(os.path.join(args.output_dir, "samples", f"sample_epoch{epoch}_step{global_step}_{i}.png"))
|
||||
|
||||
# Log to tensorboard
|
||||
accelerator.log({
|
||||
"samples": [
|
||||
wandb.Image(img, caption=f"{sample_prompts[i]}")
|
||||
for i, img in enumerate(log_images)
|
||||
]
|
||||
}, step=global_step)
|
||||
|
||||
# Set models back to training mode if they were training before
|
||||
if was_training:
|
||||
transformer.train()
|
||||
vae.train()
|
||||
|
||||
return all_images
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
# Return the noise scheduler that was created during model loading
|
||||
return self.noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, vae, images):
|
||||
return vae.encode(images)
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
return latents
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
transformer: CogView4Transformer2DModel,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet=True,
|
||||
is_train=True,
|
||||
):
|
||||
"""
|
||||
Get noise prediction and target for the loss computation.
|
||||
"""
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get text embeddings
|
||||
input_ids = batch["input_ids"]
|
||||
attention_mask = batch.get("attention_mask", None)
|
||||
|
||||
# Prepare text encoder outputs
|
||||
with torch.set_grad_enabled(self.train_text_encoder):
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True
|
||||
).last_hidden_state
|
||||
|
||||
# Predict the noise residual
|
||||
with torch.set_grad_enabled(train_unet):
|
||||
# Add conditioning dropout for classifier-free guidance
|
||||
if args.guidance_scale > 1.0:
|
||||
# Randomly drop text conditioning 5% of the time
|
||||
mask = (torch.rand(bsz, device=latents.device) < 0.05).float().unsqueeze(1).unsqueeze(1)
|
||||
text_embeddings = text_embeddings * (1 - mask) + torch.zeros_like(text_embeddings) * mask
|
||||
|
||||
# Predict noise
|
||||
model_pred = transformer(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
).sample
|
||||
|
||||
# Calculate target
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# For classifier-free guidance, we need to do two forward passes
|
||||
if args.guidance_scale > 1.0:
|
||||
# Predict the conditional and unconditional outputs
|
||||
model_pred_uncond = transformer(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=torch.zeros_like(text_embeddings),
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
).sample
|
||||
|
||||
# Perform classifier-free guidance
|
||||
model_pred = model_pred_uncond + args.guidance_scale * (model_pred - model_pred_uncond)
|
||||
|
||||
# For training, we only compute the loss on the conditional prediction
|
||||
if is_train:
|
||||
model_pred = model_pred_uncond + args.guidance_scale * (model_pred - model_pred_uncond)
|
||||
|
||||
# Simple weighting - can be adjusted based on timestep if needed
|
||||
weighting = torch.ones_like(timesteps, dtype=weight_dtype, device=latents.device)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
"""
|
||||
Post-process the loss value.
|
||||
This can include applying timestep weighting, gradient clipping, etc.
|
||||
"""
|
||||
# Apply timestep weighting if specified
|
||||
if hasattr(args, 'timestep_bias_portion') and args.timestep_bias_portion > 0.0:
|
||||
# Simple timestep weighting - can be made more sophisticated if needed
|
||||
weights = torch.ones_like(timesteps, dtype=torch.float32)
|
||||
if hasattr(args, 'timestep_bias_begin') and args.timestep_bias_begin > 0:
|
||||
mask = timesteps < args.timestep_bias_begin
|
||||
weights[mask] = 0.0
|
||||
if hasattr(args, 'timestep_bias_end') and args.timestep_bias_end < 1000:
|
||||
mask = timesteps > args.timestep_bias_end
|
||||
weights[mask] = 0.0
|
||||
if hasattr(args, 'timestep_bias_multiplier') and args.timestep_bias_multiplier != 1.0:
|
||||
weights = weights * args.timestep_bias_multiplier
|
||||
|
||||
loss = loss * weights.to(loss.device)
|
||||
|
||||
# Clip loss values if specified
|
||||
if hasattr(args, 'clip_grad_norm') and args.clip_grad_norm > 0.0:
|
||||
loss = torch.clamp(loss, -args.clip_grad_norm, args.clip_grad_norm)
|
||||
|
||||
return loss
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
"""
|
||||
Prepare extra kwargs for the scheduler step, such as the generator for reproducibility.
|
||||
"""
|
||||
# Prepare extra step kwargs.
|
||||
# TODO: Logic should ideally just be moved to base class
|
||||
accepts_eta = "eta" in set(inspect.signature(self.noise_scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# Check if the scheduler accepts a generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.noise_scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
return extra_step_kwargs
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_apply_attn_mask"] = args.apply_attn_mask
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
metadata["ss_logit_mean"] = args.logit_mean
|
||||
metadata["ss_logit_std"] = args.logit_std
|
||||
metadata["ss_mode_scale"] = args.mode_scale
|
||||
metadata["ss_guidance_scale"] = args.guidance_scale
|
||||
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
||||
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
||||
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
"""Check if text encoder outputs are cached and not being trained."""
|
||||
return getattr(args, 'cache_text_encoder_outputs', False) and not getattr(args, 'train_text_encoder', False)
|
||||
|
||||
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
||||
"""Prepare text encoder for gradient checkpointing.
|
||||
|
||||
For CogView4, we only have one text encoder (GLM) so we don't need index-based handling.
|
||||
The base class method handles enabling gradient checkpointing if args specify it.
|
||||
"""
|
||||
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
||||
|
||||
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
||||
"""Prepare text encoder for FP8 training.
|
||||
|
||||
Args:
|
||||
index: Text encoder index (always 0 for CogView4)
|
||||
text_encoder: The text encoder model (GLM)
|
||||
te_weight_dtype: Target weight dtype for the encoder
|
||||
weight_dtype: Base weight dtype for embeddings
|
||||
"""
|
||||
if index != 0:
|
||||
logger.warning(f"Unexpected text encoder index {index} for CogView4, expecting 0.")
|
||||
# Still proceed, assuming it's the single GLM encoder
|
||||
|
||||
logger.info(f"Preparing GLM text encoder (index {index}) for {te_weight_dtype}, embeddings to {weight_dtype}")
|
||||
text_encoder.to(te_weight_dtype)
|
||||
|
||||
# Move embeddings to base weight dtype if they exist
|
||||
if hasattr(text_encoder, 'word_embeddings'): # GLM typically has word_embeddings
|
||||
text_encoder.word_embeddings.to(dtype=weight_dtype)
|
||||
if hasattr(text_encoder, 'position_embeddings'): # GLM might have position_embeddings
|
||||
text_encoder.position_embeddings.to(dtype=weight_dtype)
|
||||
# Add other relevant parts of GLM if they need specific dtype handling for FP8
|
||||
return text_encoder
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
"""Called at the end of each validation step."""
|
||||
# No special handling needed for CogView4 (e.g., no block swapping)
|
||||
pass
|
||||
|
||||
def prepare_unet_with_accelerator(
|
||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||
) -> torch.nn.Module:
|
||||
"""Prepare UNet model (CogView4Transformer2DModel) with accelerator.
|
||||
|
||||
For CogView4, we use standard model preparation.
|
||||
"""
|
||||
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
flux_train_utils.add_flux_train_arguments(parser)
|
||||
return parser
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
trainer = CogView4NetworkTrainer()
|
||||
trainer.train(args)
|
||||
Reference in New Issue
Block a user