feat: support Chroma model in loading and inference processes

This commit is contained in:
kohya-ss
2025-07-20 12:56:42 +09:00
parent a96d684ffa
commit 24d2ea86c7
6 changed files with 123 additions and 133 deletions

View File

@@ -108,12 +108,18 @@ def denoise(
else:
b_img = img
# For Chroma model, y might be None, so create dummy tensor
if b_vec is None:
y_input = torch.zeros_like(b_txt[:, :1, :]) # dummy tensor
else:
y_input = b_vec
pred = model(
img=b_img,
img_ids=b_img_ids,
txt=b_txt,
txt_ids=b_txt_ids,
y=b_vec,
y=y_input,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=b_t5_attn_mask,
@@ -134,7 +140,7 @@ def do_sample(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
l_pooled: torch.Tensor,
l_pooled: Optional[torch.Tensor],
t5_out: torch.Tensor,
txt_ids: torch.Tensor,
num_steps: int,
@@ -192,7 +198,7 @@ def do_sample(
def generate_image(
model,
clip_l: CLIPTextModel,
clip_l: Optional[CLIPTextModel],
t5xxl,
ae,
prompt: str,
@@ -231,7 +237,7 @@ def generate_image(
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
# prepare fp8 models
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
if clip_l is not None and is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
clip_l.to(clip_l_dtype) # fp8
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
@@ -267,18 +273,22 @@ def generate_image(
# prepare embeddings
logger.info("Encoding prompts...")
clip_l = clip_l.to(device)
if clip_l is not None:
clip_l = clip_l.to(device)
t5xxl = t5xxl.to(device)
def encode(prpt: str):
tokens_and_masks = tokenize_strategy.tokenize(prpt)
with torch.no_grad():
if is_fp8(clip_l_dtype):
with accelerator.autocast():
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
if clip_l is not None:
if is_fp8(clip_l_dtype):
with accelerator.autocast():
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
else:
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
else:
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
l_pooled = None
if is_fp8(t5xxl_dtype):
with accelerator.autocast():
@@ -288,7 +298,7 @@ def generate_image(
else:
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
return l_pooled, t5_out, txt_ids, t5_attn_mask
@@ -305,7 +315,8 @@ def generate_image(
raise ValueError("NaN in t5_out")
if args.offload:
clip_l = clip_l.cpu()
if clip_l is not None:
clip_l = clip_l.cpu()
t5xxl = t5xxl.cpu()
# del clip_l, t5xxl
device_utils.clean_memory()
@@ -385,6 +396,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--model_type", type=str, choices=["flux", "chroma"], default="flux", help="Model type to use")
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--ae", type=str, required=False)
@@ -438,10 +450,13 @@ if __name__ == "__main__":
else:
accelerator = None
# load clip_l
logger.info(f"Loading clip_l from {args.clip_l}...")
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
clip_l.eval()
# load clip_l (skip for chroma model)
if args.model_type == "flux":
logger.info(f"Loading clip_l from {args.clip_l}...")
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
clip_l.eval()
else:
clip_l = None
logger.info(f"Loading t5xxl from {args.t5xxl}...")
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
@@ -453,7 +468,7 @@ if __name__ == "__main__":
# t5xxl = accelerator.prepare(t5xxl)
# DiT
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
model_type, is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype