feat: support Chroma model in loading and inference processes

kohya-ss committed 2025-07-20 12:56:42 +09:00
parent a96d684ffa · commit 24d2ea86c7
6 changed files with 123 additions and 133 deletions

View File

@@ -108,12 +108,18 @@ def denoise(
         else:
             b_img = img

+        # For Chroma model, y might be None, so create dummy tensor
+        if b_vec is None:
+            y_input = torch.zeros_like(b_txt[:, :1, :])  # dummy tensor
+        else:
+            y_input = b_vec
+
         pred = model(
             img=b_img,
             img_ids=b_img_ids,
             txt=b_txt,
             txt_ids=b_txt_ids,
-            y=b_vec,
+            y=y_input,
             timesteps=t_vec,
             guidance=guidance_vec,
             txt_attention_mask=b_t5_attn_mask,
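
Note: this guard is the whole Chroma accommodation on the sampling side; Chroma ships no CLIP-L, so the pooled embedding arrives as None. A minimal sketch of the same pattern in isolation (names follow the hunk; the helper itself is illustrative):

    import torch

    def make_y_input(b_vec, b_txt):
        if b_vec is None:  # Chroma: no CLIP-L pooled embedding available
            return torch.zeros_like(b_txt[:, :1, :])  # dummy, shaped from T5 hidden states
        return b_vec  # FLUX: pooled CLIP-L embedding, unchanged behavior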
@@ -134,7 +140,7 @@ def do_sample(
     model: flux_models.Flux,
     img: torch.Tensor,
     img_ids: torch.Tensor,
-    l_pooled: torch.Tensor,
+    l_pooled: Optional[torch.Tensor],
     t5_out: torch.Tensor,
     txt_ids: torch.Tensor,
     num_steps: int,
@@ -192,7 +198,7 @@ def do_sample(
 def generate_image(
     model,
-    clip_l: CLIPTextModel,
+    clip_l: Optional[CLIPTextModel],
     t5xxl,
     ae,
     prompt: str,
@@ -231,7 +237,7 @@ def generate_image(
     img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)

     # prepare fp8 models
-    if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
+    if clip_l is not None and is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
         logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
         clip_l.to(clip_l_dtype)  # fp8
         clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
@@ -267,18 +273,22 @@ def generate_image(
     # prepare embeddings
     logger.info("Encoding prompts...")
-    clip_l = clip_l.to(device)
+    if clip_l is not None:
+        clip_l = clip_l.to(device)
     t5xxl = t5xxl.to(device)

     def encode(prpt: str):
         tokens_and_masks = tokenize_strategy.tokenize(prpt)
         with torch.no_grad():
-            if is_fp8(clip_l_dtype):
-                with accelerator.autocast():
-                    l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
+            if clip_l is not None:
+                if is_fp8(clip_l_dtype):
+                    with accelerator.autocast():
+                        l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
+                else:
+                    with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
+                        l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
             else:
-                with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
-                    l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
+                l_pooled = None

             if is_fp8(t5xxl_dtype):
                 with accelerator.autocast():
@@ -288,7 +298,7 @@ def generate_image(
             else:
                 with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
                     _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
-                        tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
+                        tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
                     )

         return l_pooled, t5_out, txt_ids, t5_attn_mask
@@ -305,7 +315,8 @@ def generate_image(
         raise ValueError("NaN in t5_out")

     if args.offload:
-        clip_l = clip_l.cpu()
+        if clip_l is not None:
+            clip_l = clip_l.cpu()
         t5xxl = t5xxl.cpu()
     # del clip_l, t5xxl
     device_utils.clean_memory()
@@ -385,6 +396,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--ckpt_path", type=str, required=True)
+    parser.add_argument("--model_type", type=str, choices=["flux", "chroma"], default="flux", help="Model type to use")
     parser.add_argument("--clip_l", type=str, required=False)
     parser.add_argument("--t5xxl", type=str, required=False)
     parser.add_argument("--ae", type=str, required=False)
@@ -438,10 +450,13 @@ if __name__ == "__main__":
     else:
         accelerator = None

-    # load clip_l
-    logger.info(f"Loading clip_l from {args.clip_l}...")
-    clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
-    clip_l.eval()
+    # load clip_l (skip for chroma model)
+    if args.model_type == "flux":
+        logger.info(f"Loading clip_l from {args.clip_l}...")
+        clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
+        clip_l.eval()
+    else:
+        clip_l = None

     logger.info(f"Loading t5xxl from {args.t5xxl}...")
     t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
@@ -453,7 +468,7 @@ if __name__ == "__main__":
     # t5xxl = accelerator.prepare(t5xxl)

     # DiT
-    is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
+    model_type, is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
     model.eval()
     logger.info(f"Casting model to {flux_dtype}")
     model.to(flux_dtype)  # make sure model is dtype

View File

@@ -270,8 +270,8 @@ def train(args):
     clean_memory_on_device(accelerator.device)

     # load FLUX
-    _, flux = flux_utils.load_flow_model(
-        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
+    model_type, _, flux = flux_utils.load_flow_model(
+        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
     )

     if args.gradient_checkpointing:

View File

@@ -258,8 +258,8 @@ def train(args):
     clean_memory_on_device(accelerator.device)

     # load FLUX
-    is_schnell, flux = flux_utils.load_flow_model(
-        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
+    model_type, is_schnell, flux = flux_utils.load_flow_model(
+        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
     )

     flux.requires_grad_(False)

View File

@@ -95,8 +95,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         loading_dtype = None if args.fp8_base else weight_dtype

         # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
-        self.is_schnell, model = flux_utils.load_flow_model(
-            args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
+        self.model_type, self.is_schnell, model = flux_utils.load_flow_model(
+            args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, model_type="flux"
         )
         if args.fp8_base:
             # check dtype of model

View File

@@ -11,17 +11,7 @@ from torch import Tensor, nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as ckpt

-from .flux_models import (
-    attention,
-    rope,
-    apply_rope,
-    EmbedND,
-    timestep_embedding,
-    MLPEmbedder,
-    RMSNorm,
-    QKNorm,
-    SelfAttention
-)
+from .flux_models import attention, rope, apply_rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm, SelfAttention, Flux

 from . import custom_offloading_utils
@@ -468,13 +458,13 @@ def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
     return modified_mask


-class Chroma(nn.Module):
+class Chroma(Flux):
     """
     Transformer model for flow matching on sequences.
     """

     def __init__(self, params: ChromaParams):
-        super().__init__()
+        nn.Module.__init__(self)
         self.params = params
         self.in_channels = params.in_channels
         self.out_channels = self.in_channels
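
Note: nn.Module.__init__(self) in place of super().__init__() is deliberate: Chroma gains the Flux interface through inheritance while skipping Flux.__init__, which would otherwise allocate Flux's own submodules. A minimal, self-contained sketch of the pattern (toy classes, not the real models):

    import torch.nn as nn

    class Base(nn.Module):
        def __init__(self):
            super().__init__()
            self.heavy = nn.Linear(4096, 4096)  # allocation the subclass wants to skip

        def helper(self):
            return "shared behavior"

    class Derived(Base):
        def __init__(self):
            nn.Module.__init__(self)  # register module internals, skip Base.__init__
            self.small = nn.Linear(8, 8)

    d = Derived()
    assert not hasattr(d, "heavy")  # Base's allocations never ran
    print(d.helper())               # Base's methods are still inherited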
@@ -548,60 +538,9 @@ class Chroma(nn.Module):
         self.num_double_blocks = len(self.double_blocks)
         self.num_single_blocks = len(self.single_blocks)

-    @property
-    def device(self):
-        # Get the device of the module (assumes all parameters are on the same device)
-        return next(self.parameters()).device
-
-    def enable_gradient_checkpointing(self):
-        self.distilled_guidance_layer.enable_gradient_checkpointing()
-        for block in self.double_blocks + self.single_blocks:
-            block.enable_gradient_checkpointing()
-
-    def disable_gradient_checkpointing(self):
-        self.distilled_guidance_layer.disable_gradient_checkpointing()
-        for block in self.double_blocks + self.single_blocks:
-            block.disable_gradient_checkpointing()
-
-    def enable_block_swap(self, num_blocks: int, device: torch.device):
-        self.blocks_to_swap = num_blocks
-        double_blocks_to_swap = num_blocks // 2
-        single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
-
-        assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
-            f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
-            f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
-        )
-
-        self.offloader_double = custom_offloading_utils.ModelOffloader(
-            self.double_blocks, double_blocks_to_swap, device
-        )
-        self.offloader_single = custom_offloading_utils.ModelOffloader(
-            self.single_blocks, single_blocks_to_swap, device
-        )
-        print(
-            f"Chroma: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
-        )
-
-    def move_to_device_except_swap_blocks(self, device: torch.device):
-        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
-        if self.blocks_to_swap:
-            save_double_blocks = self.double_blocks
-            save_single_blocks = self.single_blocks
-            self.double_blocks = None
-            self.single_blocks = None
-
-        self.to(device)
-
-        if self.blocks_to_swap:
-            self.double_blocks = save_double_blocks
-            self.single_blocks = save_single_blocks
-
-    def prepare_block_swap_before_forward(self):
-        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
-            return
-        self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
-        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
+        # Initialize properties required by Flux parent class
+        self.gradient_checkpointing = False
+        self.cpu_offload_checkpointing = False

     def forward(
         self,
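
Note: the deleted helpers duplicated methods that flux_models.Flux already defines (the device property, gradient-checkpointing toggles, and the block-swap/offloading machinery), so they now resolve on the parent class; only the two flags that Flux's code paths read are initialized here. A hedged sketch of the caller-side effect (the method names appear in the deleted block and in Flux; setup_model itself is illustrative):

    def setup_model(model, blocks_to_swap, device):
        # One code path now serves Flux and Chroma, since both expose the
        # same interface; Chroma inherits these methods instead of redefining them.
        model.enable_gradient_checkpointing()
        if blocks_to_swap:
            model.enable_block_swap(blocks_to_swap, device)
            model.move_to_device_except_swap_blocks(device)
            model.prepare_block_swap_before_forward()
        return model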
@@ -609,10 +548,12 @@ class Chroma(nn.Module):
         img_ids: Tensor,
         txt: Tensor,
         txt_ids: Tensor,
-        txt_mask: Tensor,
         timesteps: Tensor,
-        guidance: Tensor,
-        attn_padding: int = 1,
+        y: Tensor,
+        block_controlnet_hidden_states=None,
+        block_controlnet_single_hidden_states=None,
+        guidance: Tensor | None = None,
+        txt_attention_mask: Tensor | None = None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -654,11 +595,11 @@ class Chroma(nn.Module):
         # mask
         with torch.no_grad():
-            txt_mask_w_padding = modify_mask_to_attend_padding(txt_mask, max_len, attn_padding)
+            txt_mask_w_padding = modify_mask_to_attend_padding(txt_attention_mask, max_len, 1)
             txt_img_mask = torch.cat(
                 [
                     txt_mask_w_padding,
-                    torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device),
+                    torch.ones([img.shape[0], img.shape[1]], device=txt_attention_mask.device),
                 ],
                 dim=1,
             )
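
Note: the joint attention mask above concatenates the padding-adjusted T5 text mask with an all-ones block, so image tokens are always attended while masked text positions stay hidden. A shape-only sketch (sizes are illustrative):

    import torch

    bsz, txt_len, img_len = 2, 512, 1024
    txt_mask = torch.ones(bsz, txt_len)  # stands in for txt_mask_w_padding
    img_mask = torch.ones(bsz, img_len)  # image tokens: always attended
    txt_img_mask = torch.cat([txt_mask, img_mask], dim=1)
    assert txt_img_mask.shape == (bsz, txt_len + img_len)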

View File

@@ -92,50 +92,84 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
 def load_flow_model(
-    ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
-) -> Tuple[bool, flux_models.Flux]:
-    is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
-    name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
+    ckpt_path: str,
+    dtype: Optional[torch.dtype],
+    device: Union[str, torch.device],
+    disable_mmap: bool = False,
+    model_type: str = "flux",
+) -> Tuple[str, bool, flux_models.Flux]:
+    if model_type == "flux":
+        is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
+        name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL

-    # build model
-    logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
-    with torch.device("meta"):
-        params = flux_models.configs[name].params
+        # build model
+        logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
+        with torch.device("meta"):
+            params = flux_models.configs[name].params

-        # set the number of blocks
-        if params.depth != num_double_blocks:
-            logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
-            params = replace(params, depth=num_double_blocks)
-        if params.depth_single_blocks != num_single_blocks:
-            logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
-            params = replace(params, depth_single_blocks=num_single_blocks)
+            # set the number of blocks
+            if params.depth != num_double_blocks:
+                logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
+                params = replace(params, depth=num_double_blocks)
+            if params.depth_single_blocks != num_single_blocks:
+                logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
+                params = replace(params, depth_single_blocks=num_single_blocks)

-        model = flux_models.Flux(params)
-        if dtype is not None:
-            model = model.to(dtype)
+            model = flux_models.Flux(params)
+            if dtype is not None:
+                model = model.to(dtype)

-    # load_sft doesn't support torch.device
-    logger.info(f"Loading state dict from {ckpt_path}")
-    sd = {}
-    for ckpt_path in ckpt_paths:
-        sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
+        # load_sft doesn't support torch.device
+        logger.info(f"Loading state dict from {ckpt_path}")
+        sd = {}
+        for ckpt_path in ckpt_paths:
+            sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))

-    # convert Diffusers to BFL
-    if is_diffusers:
-        logger.info("Converting Diffusers to BFL")
-        sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
-        logger.info("Converted Diffusers to BFL")
+        # convert Diffusers to BFL
+        if is_diffusers:
+            logger.info("Converting Diffusers to BFL")
+            sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
+            logger.info("Converted Diffusers to BFL")

-    # if the key has annoying prefix, remove it
-    for key in list(sd.keys()):
-        new_key = key.replace("model.diffusion_model.", "")
-        if new_key == key:
-            break  # the model doesn't have annoying prefix
-        sd[new_key] = sd.pop(key)
+        # if the key has annoying prefix, remove it
+        for key in list(sd.keys()):
+            new_key = key.replace("model.diffusion_model.", "")
+            if new_key == key:
+                break  # the model doesn't have annoying prefix
+            sd[new_key] = sd.pop(key)

-    info = model.load_state_dict(sd, strict=False, assign=True)
-    logger.info(f"Loaded Flux: {info}")
-    return is_schnell, model
+        info = model.load_state_dict(sd, strict=False, assign=True)
+        logger.info(f"Loaded Flux: {info}")
+        return model_type, is_schnell, model
+    elif model_type == "chroma":
+        from . import chroma_models
+
+        # build model
+        logger.info("Building Chroma model from BFL checkpoint")
+        with torch.device("meta"):
+            model = chroma_models.Chroma(chroma_models.chroma_params)
+            if dtype is not None:
+                model = model.to(dtype)
+
+        # load_sft doesn't support torch.device
+        logger.info(f"Loading state dict from {ckpt_path}")
+        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
+
+        # if the key has annoying prefix, remove it
+        for key in list(sd.keys()):
+            new_key = key.replace("model.diffusion_model.", "")
+            if new_key == key:
+                break  # the model doesn't have annoying prefix
+            sd[new_key] = sd.pop(key)
+
+        info = model.load_state_dict(sd, strict=False, assign=True)
+        logger.info(f"Loaded Chroma: {info}")
+        is_schnell = False  # Chroma is not schnell
+        return model_type, is_schnell, model
+    else:
+        raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")


 def load_ae(
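
Note: both branches build the model on the meta device and then adopt the checkpoint tensors directly via load_state_dict(..., assign=True), so full-size weights are never allocated twice. A minimal, self-contained sketch of that pattern (requires PyTorch >= 2.1; nn.Linear stands in for the real model, the dict stands in for load_safetensors):

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        model = nn.Linear(16, 16)  # parameters exist only as meta tensors here

    sd = {"weight": torch.randn(16, 16), "bias": torch.zeros(16)}
    info = model.load_state_dict(sd, strict=False, assign=True)  # adopts the tensors
    print(info)  # <All keys matched successfully>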
@@ -166,7 +200,7 @@ def load_controlnet(
     sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
     info = controlnet.load_state_dict(sd, strict=False, assign=True)
     logger.info(f"Loaded ControlNet: {info}")
-    return controlnet
+    return controlnet


 def load_clip_l(