add flux controlnet base module

This commit is contained in:
minux302
2024-11-15 20:21:28 +09:00
parent 0047bb1fc3
commit ccfaa001e7
4 changed files with 841 additions and 2 deletions

View File

@@ -1013,6 +1013,8 @@ class Flux(nn.Module):
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
block_controlnet_hidden_states=None,
block_controlnet_single_hidden_states=None,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
@@ -1031,18 +1033,29 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
if block_controlnet_hidden_states is not None:
controlnet_depth = len(block_controlnet_hidden_states)
if block_controlnet_single_hidden_states is not None:
controlnet_single_depth = len(block_controlnet_single_hidden_states)
if not self.blocks_to_swap:
for block in self.double_blocks:
for block_idx, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_controlnet_hidden_states is not None:
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_controlnet_single_hidden_states is not None:
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
else:
for block_idx, block in enumerate(self.double_blocks):
self.offloader_double.wait_for_block(block_idx)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_controlnet_hidden_states is not None:
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
@@ -1052,6 +1065,8 @@ class Flux(nn.Module):
self.offloader_single.wait_for_block(block_idx)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_controlnet_single_hidden_states is not None:
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
@@ -1066,6 +1081,246 @@ class Flux(nn.Module):
return img
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
class ControlNetFlux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: FluxParams, controlnet_depth=2):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(0) # TMP
]
)
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.blocks_to_swap = None
self.offloader_double = None
self.offloader_single = None
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
# add ControlNet blocks
self.controlnet_blocks_for_double = nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks_for_double.append(controlnet_block)
self.controlnet_blocks_for_single = nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks_for_single.append(controlnet_block)
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.gradient_checkpointing = False
self.input_hint_block = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(nn.Conv2d(16, 16, 3, padding=1))
)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
self.gradient_checkpointing = True
self.cpu_offload_checkpointing = cpu_offload
self.time_in.enable_gradient_checkpointing()
self.vector_in.enable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.enable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.time_in.disable_gradient_checkpointing()
self.vector_in.disable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.disable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.disable_gradient_checkpointing()
print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
)
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.to(device)
if self.blocks_to_swap:
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks
def prepare_block_swap_before_forward(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
def forward(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> tuple[tuple[Tensor]]:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_samples = ()
block_single_samples = ()
if not self.blocks_to_swap:
for block_idx, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
block_samples = block_samples + (img,)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
block_single_samples = block_single_samples + (img,)
else:
for block_idx, block in enumerate(self.double_blocks):
self.offloader_double.wait_for_block(block_idx)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
block_samples = block_samples + (img,)
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
img = torch.cat((txt, img), 1)
for block_idx, block in enumerate(self.single_blocks):
self.offloader_single.wait_for_block(block_idx)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
block_single_samples = block_single_samples + (img,)
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
controlnet_block_samples = ()
controlnet_single_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single):
block_sample = controlnet_block(block_sample)
controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,)
return controlnet_block_samples, controlnet_single_block_samples
"""
class FluxUpper(nn.Module):
""

View File

@@ -153,6 +153,14 @@ def load_ae(
return ae
def load_controlnet(name, device, transformer=None):
with torch.device(device):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
if transformer is not None:
controlnet.load_state_dict(transformer.state_dict(), strict=False)
return controlnet
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,