feat: add LoRA training support for Chroma

Kohya S
2025-07-20 19:00:09 +09:00
parent c4958b5dca
commit b4e862626a
10 changed files with 158 additions and 260 deletions


@@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
+ self.model_type: Optional[str] = None
def assert_extra_args(
self,
@@ -45,6 +46,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
+ self.model_type = args.model_type  # "flux" or "chroma"
+ if self.model_type != "chroma":
+     self.use_clip_l = True
+ else:
+     self.use_clip_l = False  # Chroma does not use CLIP-L
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
@@ -60,7 +67,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare CLIP-L/T5XXL training flags
- self.train_clip_l = not args.network_train_unet_only
+ self.train_clip_l = not args.network_train_unet_only and self.use_clip_l
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
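
Taken together, the two changes above mean CLIP-L is neither loaded for training nor trained when the model type is Chroma. A minimal standalone sketch of how the flags resolve (the helper name and bare booleans are illustrative, not part of the patch):

def resolve_clip_l_flags(model_type: str, network_train_unet_only: bool):
    # Chroma drops CLIP-L entirely, so it is neither used nor trained.
    use_clip_l = model_type != "chroma"
    train_clip_l = (not network_train_unet_only) and use_clip_l
    return use_clip_l, train_clip_l

print(resolve_clip_l_flags("flux", network_train_unet_only=False))    # (True, True)
print(resolve_clip_l_flags("chroma", network_train_unet_only=False))  # (False, False)
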
@@ -95,8 +102,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
- self.model_type, self.is_schnell, model = flux_utils.load_flow_model(
-     args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, model_type="flux"
+ _, model = flux_utils.load_flow_model(
+     args.pretrained_model_name_or_path,
+     loading_dtype,
+     "cpu",
+     disable_mmap=args.disable_mmap_load_safetensors,
+     model_type=self.model_type,
)
if args.fp8_base:
# check dtype of model
@@ -120,7 +131,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
- clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
+ if self.use_clip_l:
+     clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
+ else:
+     clip_l = flux_utils.dummy_clip_l()  # dummy CLIP-L for Chroma, which does not use CLIP-L
clip_l.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
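
`flux_utils.dummy_clip_l()` itself is not shown in this diff. Conceptually it only needs to satisfy the interface the trainer touches (`.to()`, `.eval()`, producing a pooled embedding), so a rough stand-in could look like the sketch below; the class name, the 768-dim output, and the `forward` signature are assumptions, not the actual implementation:

import torch

class DummyCLIPL(torch.nn.Module):
    """Placeholder text encoder for models such as Chroma that do not use CLIP-L."""

    def __init__(self, pooled_dim: int = 768):
        super().__init__()
        self.pooled_dim = pooled_dim
        # a registered buffer makes .to(device)/.to(dtype) behave like a real module
        self.register_buffer("_anchor", torch.zeros(1))

    def forward(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor:
        # an all-zero pooled embedding keeps downstream shapes uniform while contributing nothing
        batch_size = input_ids.shape[0]
        return torch.zeros(batch_size, self.pooled_dim, dtype=self._anchor.dtype, device=self._anchor.device)
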
@@ -141,13 +155,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
- return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
+ model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA
+ return model_version, [clip_l, t5xxl], ae, model
def get_tokenize_strategy(self, args):
- _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
+ # This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
+ # Instead, we analyze the checkpoint state to determine if it is schnell.
+ if args.model_type != "chroma":
+     _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
+ else:
+     is_schnell = False
+ self.is_schnell = is_schnell
if args.t5xxl_max_token_length is None:
- if is_schnell:
+ if self.is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
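
Because Chroma is never treated as schnell, its default T5-XXL token budget resolves to 512 unless `--t5xxl_max_token_length` is given. The resolution logic as a tiny self-contained sketch (the function name is illustrative):

from typing import Optional

def resolve_t5xxl_max_tokens(model_type: str, is_schnell: bool, override: Optional[int] = None) -> int:
    if override is not None:          # --t5xxl_max_token_length wins if set
        return override
    if model_type == "chroma":        # Chroma is never schnell
        is_schnell = False
    return 256 if is_schnell else 512

assert resolve_t5xxl_max_tokens("chroma", is_schnell=False) == 512
assert resolve_t5xxl_max_tokens("flux", is_schnell=True) == 256
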
@@ -268,23 +289,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)
- # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
- # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
- # # get size embeddings
- # orig_size = batch["original_sizes_hw"]
- # crop_size = batch["crop_top_lefts"]
- # target_size = batch["target_sizes_hw"]
- # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
- # # concat embeddings
- # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
- # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
- # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
- # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
- # return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
@@ -292,36 +296,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
- # return
- """
- class FluxUpperLowerWrapper(torch.nn.Module):
- def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
- super().__init__()
- self.flux_upper = flux_upper
- self.flux_lower = flux_lower
- self.target_device = device
- def prepare_block_swap_before_forward(self):
- pass
- def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
- self.flux_lower.to("cpu")
- clean_memory_on_device(self.target_device)
- self.flux_upper.to(self.target_device)
- img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
- self.flux_upper.to("cpu")
- clean_memory_on_device(self.target_device)
- self.flux_lower.to(self.target_device)
- return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
- wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
- clean_memory_on_device(accelerator.device)
- flux_train_utils.sample_images(
- accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
- )
- clean_memory_on_device(accelerator.device)
- """
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
@@ -366,7 +340,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# ensure guidance_scale in args is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# ensure the hidden state will require grad
+ # get modulation vectors for Chroma
+ input_vec = None
+ if self.model_type == "chroma":
+     input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz)
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
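
`get_input_vec` lives in the Chroma model code and is outside this diff; the trainer only precomputes the vector once per step so it can be passed into `call_dit` and, later, sliced for diff-output-preservation. As a rough mental model (the sinusoidal scheme, the embedding width, and the exact inputs are assumptions), Chroma derives its modulation conditioning from the timestep and guidance values rather than from a pooled CLIP vector, roughly along these lines:

import math
import torch

def sinusoidal_embedding(x: torch.Tensor, dim: int = 256, max_period: float = 10000.0) -> torch.Tensor:
    """Standard sinusoidal embedding of a 1D tensor of scalars -> (batch, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32, device=x.device) / half)
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

def build_input_vec(timesteps: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
    # Illustrative only: concatenate timestep and guidance embeddings into one conditioning
    # vector per sample; the real Chroma model feeds something like this through a small
    # approximator network to produce per-block modulation vectors.
    return torch.cat([sinusoidal_embedding(timesteps), sinusoidal_embedding(guidance)], dim=-1)

vec = build_input_vec(torch.randint(0, 1000, (4,)), torch.full((4,), 3.5))
print(vec.shape)  # torch.Size([4, 512])
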
@@ -374,13 +352,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)
+ if input_vec is not None:
+     input_vec.requires_grad_(True)
# Predict the noise residual
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
- def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
+ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec):
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
@@ -393,6 +373,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
+ input_vec=input_vec,
)
return model_pred
@@ -405,6 +386,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps,
guidance_vec=guidance_vec,
t5_attn_mask=t5_attn_mask,
+ input_vec=input_vec,
)
# unpack latents
@@ -436,6 +418,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps[diff_output_pr_indices],
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
+ input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
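
The diff-output-preservation pass re-runs the model on a subset of the batch, so every per-sample tensor must be indexed with the same indices; `input_vec` simply joins the list of optional tensors that are sliced None-safely. A tiny illustration of that pattern (names are placeholders):

import torch

def slice_optional(t, indices):
    # optional conditioning (guidance, attention mask, input_vec) stays None if it was None
    return t[indices] if t is not None else None

indices = torch.tensor([0, 2])
print(slice_optional(torch.randn(4, 512), indices).shape)  # torch.Size([2, 512])
print(slice_optional(None, indices))                       # None
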
@@ -454,9 +437,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return loss
def get_sai_model_spec(self, args):
- return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
+ if self.model_type != "chroma":
+     model_description = "schnell" if self.is_schnell else "dev"
+ else:
+     model_description = "chroma"
+ return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description)
def update_metadata(self, metadata, args):
metadata["ss_model_type"] = args.model_type
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean