mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
feat: add LoRA training support for Chroma
This commit is contained in:
@@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.sample_prompts_te_outputs = None
|
||||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
self.model_type: Optional[str] = None
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
@@ -45,6 +46,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
self.model_type = args.model_type # "flux" or "chroma"
|
||||
if self.model_type != "chroma":
|
||||
self.use_clip_l = True
|
||||
else:
|
||||
self.use_clip_l = False # Chroma does not use CLIP-L
|
||||
|
||||
if args.fp8_base_unet:
|
||||
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
||||
|
||||
@@ -60,7 +67,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# prepare CLIP-L/T5XXL training flags
|
||||
self.train_clip_l = not args.network_train_unet_only
|
||||
self.train_clip_l = not args.network_train_unet_only and self.use_clip_l
|
||||
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
||||
|
||||
if args.max_token_length is not None:
|
||||
@@ -95,8 +102,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
loading_dtype = None if args.fp8_base else weight_dtype
|
||||
|
||||
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
||||
self.model_type, self.is_schnell, model = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, model_type="flux"
|
||||
_, model = flux_utils.load_flow_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
loading_dtype,
|
||||
"cpu",
|
||||
disable_mmap=args.disable_mmap_load_safetensors,
|
||||
model_type=self.model_type,
|
||||
)
|
||||
if args.fp8_base:
|
||||
# check dtype of model
|
||||
@@ -120,7 +131,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
if self.use_clip_l:
|
||||
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
else:
|
||||
clip_l = flux_utils.dummy_clip_l() # dummy CLIP-L for Chroma, which does not use CLIP-L
|
||||
clip_l.eval()
|
||||
|
||||
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
||||
@@ -141,13 +155,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
|
||||
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
||||
model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA
|
||||
return model_version, [clip_l, t5xxl], ae, model
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
# This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
|
||||
# Instead, we analyze the checkpoint state to determine if it is schnell.
|
||||
if args.model_type != "chroma":
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
else:
|
||||
is_schnell = False
|
||||
self.is_schnell = is_schnell
|
||||
|
||||
if args.t5xxl_max_token_length is None:
|
||||
if is_schnell:
|
||||
if self.is_schnell:
|
||||
t5xxl_max_token_length = 256
|
||||
else:
|
||||
t5xxl_max_token_length = 512
|
||||
@@ -268,23 +289,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
|
||||
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# # get size embeddings
|
||||
# orig_size = batch["original_sizes_hw"]
|
||||
# crop_size = batch["crop_top_lefts"]
|
||||
# target_size = batch["target_sizes_hw"]
|
||||
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
||||
|
||||
# # concat embeddings
|
||||
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
||||
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
# return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
||||
text_encoders = text_encoder # for compatibility
|
||||
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
||||
@@ -292,36 +296,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
|
||||
)
|
||||
# return
|
||||
|
||||
"""
|
||||
class FluxUpperLowerWrapper(torch.nn.Module):
|
||||
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
|
||||
super().__init__()
|
||||
self.flux_upper = flux_upper
|
||||
self.flux_lower = flux_lower
|
||||
self.target_device = device
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
pass
|
||||
|
||||
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
|
||||
self.flux_lower.to("cpu")
|
||||
clean_memory_on_device(self.target_device)
|
||||
self.flux_upper.to(self.target_device)
|
||||
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(self.target_device)
|
||||
self.flux_lower.to(self.target_device)
|
||||
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
|
||||
|
||||
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
|
||||
)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
"""
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
@@ -366,7 +340,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# ensure guidance_scale in args is float
|
||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
# get modulation vectors for Chroma
|
||||
input_vec = None
|
||||
if self.model_type == "chroma":
|
||||
input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
noisy_model_input.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
@@ -374,13 +352,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
t.requires_grad_(True)
|
||||
img_ids.requires_grad_(True)
|
||||
guidance_vec.requires_grad_(True)
|
||||
if input_vec is not None:
|
||||
input_vec.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec):
|
||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
@@ -393,6 +373,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
input_vec=input_vec,
|
||||
)
|
||||
return model_pred
|
||||
|
||||
@@ -405,6 +386,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps,
|
||||
guidance_vec=guidance_vec,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
input_vec=input_vec,
|
||||
)
|
||||
|
||||
# unpack latents
|
||||
@@ -436,6 +418,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
timesteps=timesteps[diff_output_pr_indices],
|
||||
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
||||
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
||||
input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
|
||||
@@ -454,9 +437,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
||||
if self.model_type != "chroma":
|
||||
model_description = "schnell" if self.is_schnell else "dev"
|
||||
else:
|
||||
model_description = "chroma"
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description)
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_model_type"] = args.model_type
|
||||
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
metadata["ss_logit_mean"] = args.logit_mean
|
||||
|
||||
Reference in New Issue
Block a user