feat: support Chroma model in loading and inference processes

This commit is contained in:
kohya-ss
2025-07-20 12:56:42 +09:00
parent a96d684ffa
commit 24d2ea86c7
6 changed files with 123 additions and 133 deletions

View File

@@ -11,17 +11,7 @@ from torch import Tensor, nn
import torch.nn.functional as F
import torch.utils.checkpoint as ckpt
from .flux_models import (
attention,
rope,
apply_rope,
EmbedND,
timestep_embedding,
MLPEmbedder,
RMSNorm,
QKNorm,
SelfAttention
)
from .flux_models import attention, rope, apply_rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm, SelfAttention, Flux
from . import custom_offloading_utils
@@ -468,13 +458,13 @@ def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
return modified_mask
class Chroma(nn.Module):
class Chroma(Flux):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: ChromaParams):
super().__init__()
nn.Module.__init__(self)
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
@@ -548,60 +538,9 @@ class Chroma(nn.Module):
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
@property
def device(self):
# Get the device of the module (assumes all parameters are on the same device)
return next(self.parameters()).device
def enable_gradient_checkpointing(self):
self.distilled_guidance_layer.enable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.enable_gradient_checkpointing()
def disable_gradient_checkpointing(self):
self.distilled_guidance_layer.disable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.disable_gradient_checkpointing()
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, double_blocks_to_swap, device
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, single_blocks_to_swap, device
)
print(
f"Chroma: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
)
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.to(device)
if self.blocks_to_swap:
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks
def prepare_block_swap_before_forward(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
# Initialize properties required by Flux parent class
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def forward(
self,
@@ -609,10 +548,12 @@ class Chroma(nn.Module):
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
guidance: Tensor,
attn_padding: int = 1,
y: Tensor,
block_controlnet_hidden_states=None,
block_controlnet_single_hidden_states=None,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -654,11 +595,11 @@ class Chroma(nn.Module):
# mask
with torch.no_grad():
txt_mask_w_padding = modify_mask_to_attend_padding(txt_mask, max_len, attn_padding)
txt_mask_w_padding = modify_mask_to_attend_padding(txt_attention_mask, max_len, 1)
txt_img_mask = torch.cat(
[
txt_mask_w_padding,
torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device),
torch.ones([img.shape[0], img.shape[1]], device=txt_attention_mask.device),
],
dim=1,
)