Merge pull request #1813 from minux302/flux-controlnet

Add Flux ControlNet
This commit is contained in:
Kohya S.
2024-12-02 23:32:16 +09:00
committed by GitHub
6 changed files with 1212 additions and 18 deletions

View File

@@ -2,15 +2,15 @@
# license: Apache-2.0 License
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
import math
import os
import time
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
@@ -18,6 +18,7 @@ import torch
from einops import rearrange
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint
from library import custom_offloading_utils
# USE_REENTRANT = True
@@ -1013,6 +1014,8 @@ class Flux(nn.Module):
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
block_controlnet_hidden_states=None,
block_controlnet_single_hidden_states=None,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
@@ -1031,18 +1034,29 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
if block_controlnet_hidden_states is not None:
controlnet_depth = len(block_controlnet_hidden_states)
if block_controlnet_single_hidden_states is not None:
controlnet_single_depth = len(block_controlnet_single_hidden_states)
if not self.blocks_to_swap:
for block in self.double_blocks:
for block_idx, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_controlnet_hidden_states is not None and controlnet_depth > 0:
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
for block_idx, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
else:
for block_idx, block in enumerate(self.double_blocks):
self.offloader_double.wait_for_block(block_idx)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_controlnet_hidden_states is not None and controlnet_depth > 0:
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
@@ -1052,6 +1066,8 @@ class Flux(nn.Module):
self.offloader_single.wait_for_block(block_idx)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
@@ -1066,6 +1082,246 @@ class Flux(nn.Module):
return img
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
class ControlNetFlux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(controlnet_depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(controlnet_single_depth)
]
)
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.blocks_to_swap = None
self.offloader_double = None
self.offloader_single = None
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_blocks_for_single = nn.ModuleList([])
for _ in range(controlnet_single_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks_for_single.append(controlnet_block)
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.gradient_checkpointing = False
self.input_hint_block = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(nn.Conv2d(16, 16, 3, padding=1))
)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
self.gradient_checkpointing = True
self.cpu_offload_checkpointing = cpu_offload
self.time_in.enable_gradient_checkpointing()
self.vector_in.enable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.enable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.time_in.disable_gradient_checkpointing()
self.vector_in.disable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.disable_gradient_checkpointing()
for block in self.double_blocks + self.single_blocks:
block.disable_gradient_checkpointing()
print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
)
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.to(device)
if self.blocks_to_swap:
self.double_blocks = save_double_blocks
self.single_blocks = save_single_blocks
def prepare_block_swap_before_forward(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
def forward(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> tuple[tuple[Tensor]]:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_samples = ()
block_single_samples = ()
if not self.blocks_to_swap:
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
block_samples = block_samples + (img,)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
block_single_samples = block_single_samples + (img,)
else:
for block_idx, block in enumerate(self.double_blocks):
self.offloader_double.wait_for_block(block_idx)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
block_samples = block_samples + (img,)
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
img = torch.cat((txt, img), 1)
for block_idx, block in enumerate(self.single_blocks):
self.offloader_single.wait_for_block(block_idx)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
block_single_samples = block_single_samples + (img,)
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
controlnet_block_samples = ()
controlnet_single_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single):
block_sample = controlnet_block(block_sample)
controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,)
return controlnet_block_samples, controlnet_single_block_samples
"""
class FluxUpper(nn.Module):
""

View File

@@ -40,6 +40,7 @@ def sample_images(
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
controlnet=None
):
if steps == 0:
if not args.sample_at_first:
@@ -67,6 +68,8 @@ def sample_images(
flux = accelerator.unwrap_model(flux)
if text_encoders is not None:
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
if controlnet is not None:
controlnet = accelerator.unwrap_model(controlnet)
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
@@ -98,6 +101,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
@@ -121,6 +125,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
torch.set_rng_state(rng_state)
@@ -142,6 +147,7 @@ def sample_image_inference(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
):
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
@@ -150,7 +156,7 @@ def sample_image_inference(
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
# controlnet_image = prompt_dict.get("controlnet_image")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
@@ -169,7 +175,6 @@ def sample_image_inference(
# if negative_prompt is None:
# negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
@@ -223,10 +228,15 @@ def sample_image_inference(
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = x.float()
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image
@@ -301,18 +311,39 @@ def denoise(
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
if controlnet is not None:
block_samples, block_single_samples = controlnet(
img=img,
img_ids=img_ids,
controlnet_cond=controlnet_img,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
else:
block_samples = None
block_single_samples = None
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
@@ -432,7 +463,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input, timesteps, sigmas
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
@@ -532,6 +563,12 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors")
parser.add_argument(
"--controlnet",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors"
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,

View File

@@ -1,14 +1,14 @@
from dataclasses import replace
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union
import einops
import torch
from safetensors.torch import load_file
from safetensors import safe_open
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel
from library.utils import setup_logging
@@ -153,6 +153,22 @@ def load_ae(
return ae
def load_controlnet(
ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
):
logger.info("Building ControlNet")
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device(device):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
if ckpt_path is not None:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = controlnet.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded ControlNet: {info}")
return controlnet
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,