feat: support Chroma model in loading and inference processes
@@ -11,17 +11,7 @@ from torch import Tensor, nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as ckpt

-from .flux_models import (
-    attention,
-    rope,
-    apply_rope,
-    EmbedND,
-    timestep_embedding,
-    MLPEmbedder,
-    RMSNorm,
-    QKNorm,
-    SelfAttention
-)
+from .flux_models import attention, rope, apply_rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm, SelfAttention, Flux
 from . import custom_offloading_utils

@@ -468,13 +458,13 @@ def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
     return modified_mask


-class Chroma(nn.Module):
+class Chroma(Flux):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: ChromaParams):
-        super().__init__()
+        nn.Module.__init__(self)
         self.params = params
         self.in_channels = params.in_channels
         self.out_channels = self.in_channels
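Note on the `nn.Module.__init__(self)` call above: `Chroma` now subclasses `Flux` to inherit its shared helpers, but deliberately skips `Flux.__init__` so that Flux's own layers are never constructed. A minimal sketch of that pattern, using hypothetical `Base`/`Child` names rather than the real sd-scripts classes:

```python
import torch.nn as nn

class Base(nn.Module):
    def __init__(self):
        super().__init__()
        self.big_layer = nn.Linear(4096, 4096)  # expensive state the child does not want

    @property
    def device(self):  # shared behavior the child does want
        return next(self.parameters()).device

class Child(Base):
    def __init__(self):
        # Bypass Base.__init__ but still run nn.Module.__init__ so that
        # parameter/buffer registration machinery is set up.
        nn.Module.__init__(self)
        self.small_layer = nn.Linear(8, 8)

c = Child()
print(hasattr(c, "big_layer"))  # False: Base's layers were never created
print(c.device)                 # inherited property still works
```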
@@ -548,60 +538,9 @@ class Chroma(nn.Module):
         self.num_double_blocks = len(self.double_blocks)
         self.num_single_blocks = len(self.single_blocks)

-    @property
-    def device(self):
-        # Get the device of the module (assumes all parameters are on the same device)
-        return next(self.parameters()).device
-
-    def enable_gradient_checkpointing(self):
-        self.distilled_guidance_layer.enable_gradient_checkpointing()
-        for block in self.double_blocks + self.single_blocks:
-            block.enable_gradient_checkpointing()
-
-    def disable_gradient_checkpointing(self):
-        self.distilled_guidance_layer.disable_gradient_checkpointing()
-        for block in self.double_blocks + self.single_blocks:
-            block.disable_gradient_checkpointing()
-
-    def enable_block_swap(self, num_blocks: int, device: torch.device):
-        self.blocks_to_swap = num_blocks
-        double_blocks_to_swap = num_blocks // 2
-        single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
-
-        assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
-            f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
-            f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
-        )
-
-        self.offloader_double = custom_offloading_utils.ModelOffloader(
-            self.double_blocks, double_blocks_to_swap, device
-        )
-        self.offloader_single = custom_offloading_utils.ModelOffloader(
-            self.single_blocks, single_blocks_to_swap, device
-        )
-        print(
-            f"Chroma: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
-        )
-
-    def move_to_device_except_swap_blocks(self, device: torch.device):
-        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
-        if self.blocks_to_swap:
-            save_double_blocks = self.double_blocks
-            save_single_blocks = self.single_blocks
-            self.double_blocks = None
-            self.single_blocks = None
-
-        self.to(device)
-
-        if self.blocks_to_swap:
-            self.double_blocks = save_double_blocks
-            self.single_blocks = save_single_blocks
-
-    def prepare_block_swap_before_forward(self):
-        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
-            return
-        self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
-        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
+        # Initialize properties required by Flux parent class
+        self.gradient_checkpointing = False
+        self.cpu_offload_checkpointing = False

     def forward(
         self,
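The deleted helpers are not lost: equivalent implementations now come from the `Flux` parent class, which is why only the two flags Flux expects remain in `__init__`. The swap-budget split from the removed `enable_block_swap` (half the budget to double blocks, the remainder doubled for the cheaper single blocks) can be sanity-checked standalone; the block counts below are assumptions in the style of FLUX models, not values read from this file:

```python
# Mirrors the arithmetic from the removed enable_block_swap.
num_double_blocks, num_single_blocks = 19, 38  # assumed FLUX-style counts

for num_blocks in (4, 10, 16):
    double_blocks_to_swap = num_blocks // 2
    single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
    # Keep at least two of each block type resident, as the removed assert required.
    assert double_blocks_to_swap <= num_double_blocks - 2
    assert single_blocks_to_swap <= num_single_blocks - 2
    print(f"{num_blocks=}: {double_blocks_to_swap=}, {single_blocks_to_swap=}")
```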
@@ -609,10 +548,12 @@ class Chroma(nn.Module):
         img_ids: Tensor,
         txt: Tensor,
         txt_ids: Tensor,
-        txt_mask: Tensor,
         timesteps: Tensor,
-        guidance: Tensor,
-        attn_padding: int = 1,
+        y: Tensor,
+        block_controlnet_hidden_states=None,
+        block_controlnet_single_hidden_states=None,
+        guidance: Tensor | None = None,
+        txt_attention_mask: Tensor | None = None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
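With this signature `Chroma.forward` is call-compatible with `Flux.forward`, so shared training and inference paths need no Chroma-specific branch. A hedged sketch of a call site; every shape below is an illustrative assumption, the real dimensions come from `ChromaParams`:

```python
import torch

bsz, img_len, txt_len = 1, 256, 77  # illustrative only
img = torch.randn(bsz, img_len, 64)
img_ids = torch.zeros(bsz, img_len, 3)
txt = torch.randn(bsz, txt_len, 4096)
txt_ids = torch.zeros(bsz, txt_len, 3)
timesteps = torch.tensor([0.5])
y = torch.zeros(bsz, 768)  # pooled-embedding slot kept for Flux signature compatibility
txt_attention_mask = torch.ones(bsz, txt_len)

# pred = model(
#     img, img_ids, txt, txt_ids, timesteps, y,
#     guidance=torch.tensor([4.0]),
#     txt_attention_mask=txt_attention_mask,
# )  # -> Tensor of shape (bsz, img_len, out_channels)
```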
@@ -654,11 +595,11 @@ class Chroma(nn.Module):

         # mask
         with torch.no_grad():
-            txt_mask_w_padding = modify_mask_to_attend_padding(txt_mask, max_len, attn_padding)
+            txt_mask_w_padding = modify_mask_to_attend_padding(txt_attention_mask, max_len, 1)
             txt_img_mask = torch.cat(
                 [
                     txt_mask_w_padding,
-                    torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device),
+                    torch.ones([img.shape[0], img.shape[1]], device=txt_attention_mask.device),
                 ],
                 dim=1,
             )
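The joint mask lets every image token attend while text positions beyond the prompt stay masked, except for one extra padding token (previously the tunable `attn_padding` argument, now hardcoded to 1). A standalone sketch of the concatenation; the `modify_mask_to_attend_padding` below is a stand-in that only approximates the helper defined earlier in this file:

```python
import torch

def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
    # Stand-in with the same contract as the helper in this file:
    # additionally unmask up to num_extra_padding positions after the prompt.
    modified = mask.clone()
    for i in range(mask.shape[0]):
        true_len = int(mask[i].sum())
        modified[i, true_len : min(true_len + num_extra_padding, max_seq_length)] = 1
    return modified

txt_attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0, 0.0]])  # 3 real text tokens
img = torch.randn(1, 4, 64)                                          # 4 image tokens
max_len = txt_attention_mask.shape[1]

txt_mask_w_padding = modify_mask_to_attend_padding(txt_attention_mask, max_len, 1)
txt_img_mask = torch.cat(
    [txt_mask_w_padding, torch.ones([img.shape[0], img.shape[1]])],
    dim=1,
)
print(txt_img_mask)  # tensor([[1., 1., 1., 1., 0., 0., 1., 1., 1., 1.]])
```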