mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
- Add --cdc_adaptive_k flag to enable adaptive k based on bucket size - Add --cdc_min_bucket_size to set minimum bucket threshold (default: 16) - Fixed mode (default): Skip buckets with < k_neighbors samples - Adaptive mode: Use k=min(k_neighbors, bucket_size-1) for buckets >= min_bucket_size - Update CDCPreprocessor to support adaptive k per bucket - Add metadata tracking for adaptive_k and min_bucket_size - Add comprehensive pytest tests for adaptive k behavior This allows CDC-FM to work effectively with multi-resolution bucketing where bucket sizes may vary widely. Users can choose between strict paper methodology (fixed k) or pragmatic approach (adaptive k).
628 lines
30 KiB
Python
628 lines
30 KiB
Python
import argparse
|
||
import copy
|
||
from typing import Any, Optional, Union
|
||
|
||
import torch
|
||
from accelerate import Accelerator
|
||
|
||
from library.device_utils import clean_memory_on_device, init_ipex
|
||
|
||
init_ipex()
|
||
|
||
import train_network
|
||
from library import (
|
||
flux_models,
|
||
flux_train_utils,
|
||
flux_utils,
|
||
sd3_train_utils,
|
||
strategy_base,
|
||
strategy_flux,
|
||
train_util,
|
||
)
|
||
from library.utils import setup_logging
|
||
|
||
setup_logging()
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||
    def __init__(self):
        """Initialize FLUX-specific trainer state on top of the generic NetworkTrainer."""
        super().__init__()
        # Cached Text Encoder outputs for sample prompts; filled by
        # cache_text_encoder_outputs_if_needed when --sample_prompts is given.
        self.sample_prompts_te_outputs = None
        # True for FLUX.1 schnell checkpoints; determined in get_tokenize_strategy.
        self.is_schnell: Optional[bool] = None
        # True when --blocks_to_swap is active; set in load_target_model.
        self.is_swapping_blocks: bool = False
        # "flux" or "chroma"; copied from args in assert_extra_args.
        self.model_type: Optional[str] = None
        self.gamma_b_dataset = None  # CDC-FM Γ_b dataset
    def assert_extra_args(
        self,
        args,
        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
        val_dataset_group: Optional[train_util.DatasetGroup],
    ):
        """Validate FLUX/Chroma-specific arguments and normalize dependent options.

        Mutates `args` in place (fp8_base, cache_text_encoder_outputs, blocks_to_swap)
        and sets trainer flags (model_type, use_clip_l, train_clip_l, train_t5xxl).
        Raises AssertionError for incompatible option combinations.
        """
        super().assert_extra_args(args, train_dataset_group, val_dataset_group)
        # sdxl_train_util.verify_sdxl_training_args(args)

        self.model_type = args.model_type  # "flux" or "chroma"
        if self.model_type != "chroma":
            self.use_clip_l = True
        else:
            self.use_clip_l = False  # Chroma does not use CLIP-L
            assert args.apply_t5_attn_mask, "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"

        if args.fp8_base_unet:
            args.fp8_base = True  # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1

        # Caching TE outputs to disk implies caching in general.
        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
            logger.warning(
                "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
            )
            args.cache_text_encoder_outputs = True

        if args.cache_text_encoder_outputs:
            # Caption-mutating options would invalidate cached Text Encoder outputs.
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

        # prepare CLIP-L/T5XXL training flags
        self.train_clip_l = not args.network_train_unet_only and self.use_clip_l
        self.train_t5xxl = False  # default is False even if args.network_train_unet_only is False

        if args.max_token_length is not None:
            logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")

        assert (
            args.blocks_to_swap is None or args.blocks_to_swap == 0
        ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

        # deprecated split_mode option
        if args.split_mode:
            if args.blocks_to_swap is not None:
                # Explicit --blocks_to_swap wins over the deprecated flag.
                logger.warning(
                    "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
                    " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
                )
            else:
                logger.warning(
                    "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
                    " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
                )
                args.blocks_to_swap = 18  # 18 is safe for most cases

        train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
        if val_dataset_group is not None:
            val_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
    def load_target_model(self, args, weight_dtype, accelerator):
        """Load the FLUX/Chroma transformer, text encoders, and autoencoder (to CPU).

        Returns (model_version, [clip_l, t5xxl], ae, model). With --fp8_base, a
        checkpoint already stored as float8_e4m3fn is used as-is; other float8
        storage formats are rejected; anything else is cast to float8_e4m3fn.
        Also configures block swapping when --blocks_to_swap is set.
        """
        # currently offload to cpu for some models

        # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
        loading_dtype = None if args.fp8_base else weight_dtype

        # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
        _, model = flux_utils.load_flow_model(
            args.pretrained_model_name_or_path,
            loading_dtype,
            "cpu",
            disable_mmap=args.disable_mmap_load_safetensors,
            model_type=self.model_type,
        )
        if args.fp8_base:
            # check dtype of model
            if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
            elif model.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 FLUX model")
            else:
                logger.info(
                    "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
                    " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
                )
                model.to(torch.float8_e4m3fn)

        # if args.split_mode:
        #     model = self.prepare_split_model(model, weight_dtype, accelerator)

        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
        if self.is_swapping_blocks:
            # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
            logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
            model.enable_block_swap(args.blocks_to_swap, accelerator.device)

        if self.use_clip_l:
            clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
        else:
            clip_l = flux_utils.dummy_clip_l()  # dummy CLIP-L for Chroma, which does not use CLIP-L
        clip_l.eval()

        # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
        if args.fp8_base and not args.fp8_base_unet:
            loading_dtype = None  # as is
        else:
            loading_dtype = weight_dtype

        # loading t5xxl to cpu takes a long time, so we should load to gpu in future
        t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
        t5xxl.eval()
        if args.fp8_base and not args.fp8_base_unet:
            # check dtype of model; only float8_e4m3fn is accepted as a pre-quantized T5XXL
            if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
            elif t5xxl.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 T5XXL model")

        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

        model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA
        return model_version, [clip_l, t5xxl], ae, model
def get_tokenize_strategy(self, args):
|
||
# This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
|
||
# Instead, we analyze the checkpoint state to determine if it is schnell.
|
||
if args.model_type != "chroma":
|
||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||
else:
|
||
is_schnell = False
|
||
self.is_schnell = is_schnell
|
||
|
||
if args.t5xxl_max_token_length is None:
|
||
if self.is_schnell:
|
||
t5xxl_max_token_length = 256
|
||
else:
|
||
t5xxl_max_token_length = 512
|
||
else:
|
||
t5xxl_max_token_length = args.t5xxl_max_token_length
|
||
|
||
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
|
||
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
|
||
|
||
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
|
||
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
|
||
|
||
def get_latents_caching_strategy(self, args):
|
||
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
||
return latents_caching_strategy
|
||
|
||
def get_text_encoding_strategy(self, args):
|
||
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
||
|
||
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
||
# check t5xxl is trained or not
|
||
self.train_t5xxl = network.train_t5xxl
|
||
|
||
if self.train_t5xxl and args.cache_text_encoder_outputs:
|
||
raise ValueError(
|
||
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
|
||
)
|
||
|
||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||
if args.cache_text_encoder_outputs:
|
||
if self.train_clip_l and not self.train_t5xxl:
|
||
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
||
else:
|
||
return None # no text encoders are needed for encoding because both are cached
|
||
else:
|
||
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
||
|
||
def get_text_encoders_train_flags(self, args, text_encoders):
|
||
return [self.train_clip_l, self.train_t5xxl]
|
||
|
||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||
if args.cache_text_encoder_outputs:
|
||
# if the text encoders is trained, we need tokenization, so is_partial is True
|
||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||
args.cache_text_encoder_outputs_to_disk,
|
||
args.text_encoder_batch_size,
|
||
args.skip_cache_check,
|
||
is_partial=self.train_clip_l or self.train_t5xxl,
|
||
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
||
)
|
||
else:
|
||
return None
|
||
|
||
    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        """Pre-compute Text Encoder outputs for the dataset (and sample prompts) when caching is enabled.

        Temporarily moves VAE/U-Net to CPU (unless --lowram) to make room for the
        encoders on GPU, runs the caching pass, then restores device placement.
        When caching is disabled, simply keeps the encoders on the accelerator
        device for per-step encoding.
        """
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # reduce memory consumption
                logger.info("move vae and unet to cpu to save memory")
                org_vae_device = vae.device
                org_unet_device = unet.device
                vae.to("cpu")
                unet.to("cpu")
                clean_memory_on_device(accelerator.device)

            # When TE is not be trained, it will not be prepared so we need to use explicit autocast
            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
            text_encoders[1].to(accelerator.device)

            if text_encoders[1].dtype == torch.float8_e4m3fn:
                # if we load fp8 weights, the model is already fp8, so we use it as is
                self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
            else:
                # otherwise, we need to convert it to target dtype
                text_encoders[1].to(weight_dtype)

            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

            # cache sample prompts
            if args.sample_prompts is not None:
                logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")

                tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
                text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

                prompts = train_util.load_prompts(args.sample_prompts)
                sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
                with accelerator.autocast(), torch.no_grad():
                    for prompt_dict in prompts:
                        # Encode both positive and negative prompts, de-duplicated by text.
                        for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                            if p not in sample_prompts_te_outputs:
                                logger.info(f"cache Text Encoder outputs for prompt: {p}")
                                tokens_and_masks = tokenize_strategy.tokenize(p)
                                sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                    tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
                                )
                self.sample_prompts_te_outputs = sample_prompts_te_outputs

            accelerator.wait_for_everyone()

            # move back to cpu
            if not self.is_train_text_encoder(args):
                logger.info("move CLIP-L back to cpu")
                text_encoders[0].to("cpu")
                logger.info("move t5XXL back to cpu")
                text_encoders[1].to("cpu")
                clean_memory_on_device(accelerator.device)

            if not args.lowram:
                logger.info("move vae and unet back to original device")
                vae.to(org_vae_device)
                unet.to(org_unet_device)
        else:
            # Outputs are fetched from the Text Encoders every step, so keep them on GPU.
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device)
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
||
text_encoders = text_encoder # for compatibility
|
||
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
||
|
||
flux_train_utils.sample_images(
|
||
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
|
||
)
|
||
|
||
    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        """Create the flow-matching Euler discrete scheduler for FLUX training."""
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
        # Keep an independent copy on self — NOTE(review): presumably used elsewhere
        # for sigma/timestep lookups without disturbing the live scheduler; confirm.
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
        return noise_scheduler
def encode_images_to_latents(self, args, vae, images):
|
||
return vae.encode(images)
|
||
|
||
def shift_scale_latents(self, args, latents):
|
||
return latents
|
||
|
||
    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet: flux_models.Flux,
        network,
        weight_dtype,
        train_unet,
        is_train=True,
    ):
        """Run the FLUX transformer and build the flow-matching target.

        Returns (model_pred, target, timesteps, weighting). When the CDC-FM
        Γ_b dataset is set and the batch carries "image_keys", the sampled noise
        is transformed with geometry-aware covariance inside
        get_noisy_model_input_and_timesteps. For batch items flagged with
        "diff_output_preservation", the target is replaced by the base-model
        (network multiplier 0) prediction.
        """
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        # Get CDC parameters if enabled (requires both the Γ_b dataset and per-item image keys)
        gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "image_keys" in batch) else None
        image_keys = batch.get("image_keys") if gamma_b_dataset is not None else None

        # Get noisy model input and timesteps
        # If CDC is enabled, this will transform the noise with geometry-aware covariance
        noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype,
            gamma_b_dataset=gamma_b_dataset, image_keys=image_keys
        )

        # pack latents and get img_ids
        packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
        packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
        img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

        # get guidance
        # ensure guidance_scale in args is float
        guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)

        # get modulation vectors for Chroma
        with accelerator.autocast(), torch.no_grad():
            mod_vectors = unet.get_mod_vectors(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)

        if args.gradient_checkpointing:
            # Inputs need requires_grad so checkpointed segments recompute with autograd attached.
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)
            img_ids.requires_grad_(True)
            guidance_vec.requires_grad_(True)
            if mod_vectors is not None:
                mod_vectors.requires_grad_(True)

        # Predict the noise residual
        l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
        if not args.apply_t5_attn_mask:
            t5_attn_mask = None

        def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, mod_vectors):
            # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
            with torch.set_grad_enabled(is_train), accelerator.autocast():
                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                model_pred = unet(
                    img=img,
                    img_ids=img_ids,
                    txt=t5_out,
                    txt_ids=txt_ids,
                    y=l_pooled,
                    timesteps=timesteps / 1000,
                    guidance=guidance_vec,
                    txt_attention_mask=t5_attn_mask,
                    mod_vectors=mod_vectors,
                )
            return model_pred

        model_pred = call_dit(
            img=packed_noisy_model_input,
            img_ids=img_ids,
            t5_out=t5_out,
            txt_ids=txt_ids,
            l_pooled=l_pooled,
            timesteps=timesteps,
            guidance_vec=guidance_vec,
            t5_attn_mask=t5_attn_mask,
            mod_vectors=mod_vectors,
        )

        # unpack latents
        model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)

        # apply model prediction type
        model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

        # flow matching loss: this is different from SD3
        target = noise - latents

        # differential output preservation
        if "custom_attributes" in batch:
            # Collect the batch indices that requested output preservation.
            diff_output_pr_indices = []
            for i, custom_attributes in enumerate(batch["custom_attributes"]):
                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
                    diff_output_pr_indices.append(i)

            if len(diff_output_pr_indices) > 0:
                # Re-run the model with the network disabled to get the base-model prediction.
                network.set_multiplier(0.0)
                unet.prepare_block_swap_before_forward()
                with torch.no_grad():
                    model_pred_prior = call_dit(
                        img=packed_noisy_model_input[diff_output_pr_indices],
                        img_ids=img_ids[diff_output_pr_indices],
                        t5_out=t5_out[diff_output_pr_indices],
                        txt_ids=txt_ids[diff_output_pr_indices],
                        l_pooled=l_pooled[diff_output_pr_indices],
                        timesteps=timesteps[diff_output_pr_indices],
                        guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
                        t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
                        mod_vectors=mod_vectors[diff_output_pr_indices] if mod_vectors is not None else None,
                    )
                network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step

                model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
                model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
                    args,
                    model_pred_prior,
                    noisy_model_input[diff_output_pr_indices],
                    sigmas[diff_output_pr_indices] if sigmas is not None else None,
                )
                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

        return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||
return loss
|
||
|
||
def get_sai_model_spec(self, args):
|
||
if self.model_type != "chroma":
|
||
model_description = "schnell" if self.is_schnell else "dev"
|
||
else:
|
||
model_description = "chroma"
|
||
return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description)
|
||
|
||
def update_metadata(self, metadata, args):
|
||
metadata["ss_model_type"] = args.model_type
|
||
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||
metadata["ss_logit_mean"] = args.logit_mean
|
||
metadata["ss_logit_std"] = args.logit_std
|
||
metadata["ss_mode_scale"] = args.mode_scale
|
||
metadata["ss_guidance_scale"] = args.guidance_scale
|
||
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
||
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
||
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||
|
||
# CDC-FM metadata
|
||
metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False)
|
||
metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None)
|
||
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
|
||
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
|
||
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
|
||
metadata["ss_cdc_adaptive_k"] = getattr(args, "cdc_adaptive_k", None)
|
||
metadata["ss_cdc_min_bucket_size"] = getattr(args, "cdc_min_bucket_size", None)
|
||
|
||
def is_text_encoder_not_needed_for_training(self, args):
|
||
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||
|
||
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
||
if index == 0: # CLIP-L
|
||
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
||
else: # T5XXL
|
||
text_encoder.encoder.embed_tokens.requires_grad_(True)
|
||
|
||
    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
        """Prepare a text encoder for fp8.

        CLIP-L (index 0): cast the whole model to te_weight_dtype (fp8) but keep
        the embeddings in weight_dtype. T5XXL (index != 0): cast to fp8, keep
        T5LayerNorm/Embedding modules in weight_dtype, and patch
        T5DenseGatedActDense.forward with a replacement implementation.
        Skips the work if T5XXL already matches the target fp8/weight dtypes.
        """
        if index == 0:  # CLIP-L
            logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
            text_encoder.to(te_weight_dtype)  # fp8
            text_encoder.text_model.embeddings.to(dtype=weight_dtype)
        else:  # T5XXL

            def prepare_fp8(text_encoder, target_dtype):
                # Replacement forward for T5DenseGatedActDense (gated-GELU MLP).
                def forward_hook(module):
                    def forward(hidden_states):
                        hidden_gelu = module.act(module.wi_0(hidden_states))
                        hidden_linear = module.wi_1(hidden_states)
                        hidden_states = hidden_gelu * hidden_linear
                        hidden_states = module.dropout(hidden_states)

                        hidden_states = module.wo(hidden_states)
                        return hidden_states

                    return forward

                for module in text_encoder.modules():
                    if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                        # print("set", module.__class__.__name__, "to", target_dtype)
                        module.to(target_dtype)
                    if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                        # print("set", module.__class__.__name__, "hooks")
                        module.forward = forward_hook(module)

            if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
                logger.info("T5XXL already prepared for fp8")
            else:
                logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
                text_encoder.to(te_weight_dtype)  # fp8
                prepare_fp8(text_encoder, weight_dtype)
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||
if self.is_swapping_blocks:
|
||
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||
|
||
    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        """Wrap the FLUX model with accelerate.

        Without block swapping the base implementation is used (normal device
        placement). With block swapping, prepare with device placement disabled
        so swapped blocks stay off the accelerator device.
        """
        if not self.is_swapping_blocks:
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        # Block swapping active: skip accelerate's device placement, then move only
        # the non-swapped parts onto the device ourselves.
        flux: flux_models.Flux = unet
        flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
        accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
        accelerator.unwrap_model(flux).prepare_block_swap_before_forward()

        return flux
def setup_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: base network-training args plus DiT/FLUX and CDC-FM options."""
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)
    flux_train_utils.add_flux_train_arguments(parser)

    # Deprecated; kept so old command lines still parse (handled in assert_extra_args).
    parser.add_argument(
        "--split_mode",
        action="store_true",
        # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
        # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
        help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
        " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
    )

    # CDC-FM arguments
    parser.add_argument(
        "--use_cdc_fm",
        action="store_true",
        help="Enable CDC-FM (Carré du Champ Flow Matching) for geometry-aware noise during training"
        " / CDC-FM(Carré du Champ Flow Matching)を有効にして幾何学的ノイズを使用",
    )
    parser.add_argument(
        "--cdc_k_neighbors",
        type=int,
        default=256,
        help="Number of neighbors for k-NN graph in CDC-FM (default: 256)"
        " / CDC-FMのk-NNグラフの近傍数(デフォルト: 256)",
    )
    parser.add_argument(
        "--cdc_k_bandwidth",
        type=int,
        default=8,
        help="Number of neighbors for bandwidth estimation in CDC-FM (default: 8)"
        " / CDC-FMの帯域幅推定の近傍数(デフォルト: 8)",
    )
    parser.add_argument(
        "--cdc_d_cdc",
        type=int,
        default=8,
        help="Dimension of CDC subspace (default: 8)"
        " / CDCサブ空間の次元(デフォルト: 8)",
    )
    parser.add_argument(
        "--cdc_gamma",
        type=float,
        default=1.0,
        help="CDC strength parameter (default: 1.0)"
        " / CDC強度パラメータ(デフォルト: 1.0)",
    )
    parser.add_argument(
        "--force_recache_cdc",
        action="store_true",
        help="Force recompute CDC cache even if valid cache exists"
        " / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算",
    )
    parser.add_argument(
        "--cdc_debug",
        action="store_true",
        help="Enable verbose CDC debug output showing bucket details"
        " / CDCの詳細デバッグ出力を有効化(バケット詳細表示)",
    )
    # Adaptive-k mode: lets CDC-FM work with multi-resolution bucketing where
    # bucket sizes vary widely (see also --cdc_min_bucket_size).
    parser.add_argument(
        "--cdc_adaptive_k",
        action="store_true",
        help="Use adaptive k_neighbors based on bucket size. If enabled, buckets smaller than k_neighbors will use "
        "k=bucket_size-1 instead of skipping CDC entirely. Buckets smaller than cdc_min_bucket_size are still skipped."
        " / バケットサイズに基づいてk_neighborsを適応的に調整。有効にすると、k_neighbors未満のバケットは"
        "CDCをスキップせずk=バケットサイズ-1を使用。cdc_min_bucket_size未満のバケットは引き続きスキップ。",
    )
    parser.add_argument(
        "--cdc_min_bucket_size",
        type=int,
        default=16,
        help="Minimum bucket size for CDC computation. Buckets with fewer samples will use standard Gaussian noise. "
        "Only relevant when --cdc_adaptive_k is enabled (default: 16)"
        " / CDC計算の最小バケットサイズ。これより少ないサンプルのバケットは標準ガウスノイズを使用。"
        "--cdc_adaptive_k有効時のみ関連(デフォルト: 16)",
    )

    return parser
if __name__ == "__main__":
    # Parse CLI arguments, merge an optional config file, then run training.
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = FluxNetworkTrainer()
    trainer.train(args)