From 24d2ea86c70482ec062412e4214ae221a22cd0a0 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 20 Jul 2025 12:56:42 +0900 Subject: [PATCH] feat: support Chroma model in loading and inference processes --- flux_minimal_inference.py | 49 +++++++++++------ flux_train.py | 4 +- flux_train_control_net.py | 4 +- flux_train_network.py | 4 +- library/chroma_models.py | 85 +++++------------------------ library/flux_utils.py | 110 +++++++++++++++++++++++++------------- 6 files changed, 123 insertions(+), 133 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 7ab224f1..a7bff74d 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -108,12 +108,18 @@ def denoise( else: b_img = img + # For Chroma model, y might be None, so create dummy tensor + if b_vec is None: + y_input = torch.zeros_like(b_txt[:, :1, :]) # dummy tensor + else: + y_input = b_vec + pred = model( img=b_img, img_ids=b_img_ids, txt=b_txt, txt_ids=b_txt_ids, - y=b_vec, + y=y_input, timesteps=t_vec, guidance=guidance_vec, txt_attention_mask=b_t5_attn_mask, @@ -134,7 +140,7 @@ def do_sample( model: flux_models.Flux, img: torch.Tensor, img_ids: torch.Tensor, - l_pooled: torch.Tensor, + l_pooled: Optional[torch.Tensor], t5_out: torch.Tensor, txt_ids: torch.Tensor, num_steps: int, @@ -192,7 +198,7 @@ def do_sample( def generate_image( model, - clip_l: CLIPTextModel, + clip_l: Optional[CLIPTextModel], t5xxl, ae, prompt: str, @@ -231,7 +237,7 @@ def generate_image( img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) # prepare fp8 models - if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): + if clip_l is not None and is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") clip_l.to(clip_l_dtype) # fp8 clip_l.text_model.embeddings.to(dtype=torch.bfloat16) @@ -267,18 +273,22 @@ def generate_image( # prepare embeddings logger.info("Encoding prompts...") - clip_l = clip_l.to(device) + if clip_l is not None: + clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) def encode(prpt: str): tokens_and_masks = tokenize_strategy.tokenize(prpt) with torch.no_grad(): - if is_fp8(clip_l_dtype): - with accelerator.autocast(): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + if clip_l is not None: + if is_fp8(clip_l_dtype): + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) else: - with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + l_pooled = None if is_fp8(t5xxl_dtype): with accelerator.autocast(): @@ -288,7 +298,7 @@ def generate_image( else: with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) return l_pooled, t5_out, txt_ids, t5_attn_mask @@ -305,7 +315,8 @@ def generate_image( raise ValueError("NaN in t5_out") if args.offload: - clip_l = clip_l.cpu() + if clip_l is not None: + clip_l = clip_l.cpu() t5xxl = t5xxl.cpu() # del clip_l, t5xxl device_utils.clean_memory() @@ -385,6 +396,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--model_type", type=str, choices=["flux", "chroma"], default="flux", help="Model type to use") parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) parser.add_argument("--ae", type=str, required=False) @@ -438,10 +450,13 @@ if __name__ == "__main__": else: accelerator = None - # load clip_l - logger.info(f"Loading clip_l from {args.clip_l}...") - clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) - clip_l.eval() + # load clip_l (skip for chroma model) + if args.model_type == "flux": + logger.info(f"Loading clip_l from {args.clip_l}...") + clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) + clip_l.eval() + else: + clip_l = None logger.info(f"Loading t5xxl from {args.t5xxl}...") t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) @@ -453,7 +468,7 @@ if __name__ == "__main__": # t5xxl = accelerator.prepare(t5xxl) # DiT - is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device) + model_type, is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype diff --git a/flux_train.py b/flux_train.py index 6f98adea..1d2cc68b 100644 --- a/flux_train.py +++ b/flux_train.py @@ -270,8 +270,8 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - _, flux = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + model_type, _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" ) if args.gradient_checkpointing: diff --git a/flux_train_control_net.py b/flux_train_control_net.py index cecd0001..3c038c32 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -258,8 +258,8 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - is_schnell, flux = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + model_type, is_schnell, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" ) flux.requires_grad_(False) diff --git a/flux_train_network.py b/flux_train_network.py index def44155..b2bf8e7c 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -95,8 +95,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): loading_dtype = None if args.fp8_base else weight_dtype # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - self.is_schnell, model = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + self.model_type, self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, model_type="flux" ) if args.fp8_base: # check dtype of model diff --git a/library/chroma_models.py b/library/chroma_models.py index 9f21afad..e1da751b 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -11,17 +11,7 @@ from torch import Tensor, nn import torch.nn.functional as F import torch.utils.checkpoint as ckpt -from .flux_models import ( - attention, - rope, - apply_rope, - EmbedND, - timestep_embedding, - MLPEmbedder, - RMSNorm, - QKNorm, - SelfAttention -) +from .flux_models import attention, rope, apply_rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm, SelfAttention, Flux from . import custom_offloading_utils @@ -468,13 +458,13 @@ def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8): return modified_mask -class Chroma(nn.Module): +class Chroma(Flux): """ Transformer model for flow matching on sequences. """ def __init__(self, params: ChromaParams): - super().__init__() + nn.Module.__init__(self) self.params = params self.in_channels = params.in_channels self.out_channels = self.in_channels @@ -548,60 +538,9 @@ class Chroma(nn.Module): self.num_double_blocks = len(self.double_blocks) self.num_single_blocks = len(self.single_blocks) - @property - def device(self): - # Get the device of the module (assumes all parameters are on the same device) - return next(self.parameters()).device - - def enable_gradient_checkpointing(self): - self.distilled_guidance_layer.enable_gradient_checkpointing() - for block in self.double_blocks + self.single_blocks: - block.enable_gradient_checkpointing() - - def disable_gradient_checkpointing(self): - self.distilled_guidance_layer.disable_gradient_checkpointing() - for block in self.double_blocks + self.single_blocks: - block.disable_gradient_checkpointing() - - def enable_block_swap(self, num_blocks: int, device: torch.device): - self.blocks_to_swap = num_blocks - double_blocks_to_swap = num_blocks // 2 - single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 - - assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( - f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " - f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." - ) - - self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, double_blocks_to_swap, device - ) - self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, single_blocks_to_swap, device - ) - print( - f"Chroma: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." - ) - - def move_to_device_except_swap_blocks(self, device: torch.device): - # assume model is on cpu. do not move blocks to device to reduce temporary memory usage - if self.blocks_to_swap: - save_double_blocks = self.double_blocks - save_single_blocks = self.single_blocks - self.double_blocks = None - self.single_blocks = None - - self.to(device) - - if self.blocks_to_swap: - self.double_blocks = save_double_blocks - self.single_blocks = save_single_blocks - - def prepare_block_swap_before_forward(self): - if self.blocks_to_swap is None or self.blocks_to_swap == 0: - return - self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) - self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + # Initialize properties required by Flux parent class + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False def forward( self, @@ -609,10 +548,12 @@ class Chroma(nn.Module): img_ids: Tensor, txt: Tensor, txt_ids: Tensor, - txt_mask: Tensor, timesteps: Tensor, - guidance: Tensor, - attn_padding: int = 1, + y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -654,11 +595,11 @@ class Chroma(nn.Module): # mask with torch.no_grad(): - txt_mask_w_padding = modify_mask_to_attend_padding(txt_mask, max_len, attn_padding) + txt_mask_w_padding = modify_mask_to_attend_padding(txt_attention_mask, max_len, 1) txt_img_mask = torch.cat( [ txt_mask_w_padding, - torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device), + torch.ones([img.shape[0], img.shape[1]], device=txt_attention_mask.device), ], dim=1, ) diff --git a/library/flux_utils.py b/library/flux_utils.py index 8be1d63e..a5cfcdff 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -92,50 +92,84 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int def load_flow_model( - ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False -) -> Tuple[bool, flux_models.Flux]: - is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) - name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + model_type: str = "flux", +) -> Tuple[str, bool, flux_models.Flux]: + if model_type == "flux": + is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - # build model - logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") - with torch.device("meta"): - params = flux_models.configs[name].params + # build model + logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") + with torch.device("meta"): + params = flux_models.configs[name].params - # set the number of blocks - if params.depth != num_double_blocks: - logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") - params = replace(params, depth=num_double_blocks) - if params.depth_single_blocks != num_single_blocks: - logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") - params = replace(params, depth_single_blocks=num_single_blocks) + # set the number of blocks + if params.depth != num_double_blocks: + logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") + params = replace(params, depth=num_double_blocks) + if params.depth_single_blocks != num_single_blocks: + logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") + params = replace(params, depth_single_blocks=num_single_blocks) - model = flux_models.Flux(params) - if dtype is not None: - model = model.to(dtype) + model = flux_models.Flux(params) + if dtype is not None: + model = model.to(dtype) - # load_sft doesn't support torch.device - logger.info(f"Loading state dict from {ckpt_path}") - sd = {} - for ckpt_path in ckpt_paths: - sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = {} + for ckpt_path in ckpt_paths: + sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) - # convert Diffusers to BFL - if is_diffusers: - logger.info("Converting Diffusers to BFL") - sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) - logger.info("Converted Diffusers to BFL") + # convert Diffusers to BFL + if is_diffusers: + logger.info("Converting Diffusers to BFL") + sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) + logger.info("Converted Diffusers to BFL") - # if the key has annoying prefix, remove it - for key in list(sd.keys()): - new_key = key.replace("model.diffusion_model.", "") - if new_key == key: - break # the model doesn't have annoying prefix - sd[new_key] = sd.pop(key) + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) - info = model.load_state_dict(sd, strict=False, assign=True) - logger.info(f"Loaded Flux: {info}") - return is_schnell, model + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Flux: {info}") + return model_type, is_schnell, model + + elif model_type == "chroma": + from . import chroma_models + + # build model + logger.info("Building Chroma model from BFL checkpoint") + with torch.device("meta"): + model = chroma_models.Chroma(chroma_models.chroma_params) + if dtype is not None: + model = model.to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Chroma: {info}") + is_schnell = False # Chroma is not schnell + return model_type, is_schnell, model + + else: + raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.") def load_ae( @@ -166,7 +200,7 @@ def load_controlnet( sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = controlnet.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded ControlNet: {info}") - return controlnet + return controlnet def load_clip_l(