Fix for adding ControlNet support

This commit is contained in:
minux302
2024-11-15 23:48:51 +09:00
parent ccfaa001e7
commit 42f6edf3a8
4 changed files with 855 additions and 531 deletions

View File

@@ -40,6 +40,7 @@ def sample_images(
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
controlnet=None
):
if steps == 0:
if not args.sample_at_first:
@@ -67,6 +68,8 @@ def sample_images(
flux = accelerator.unwrap_model(flux)
if text_encoders is not None:
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
if controlnet is not None:
controlnet = accelerator.unwrap_model(controlnet)
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
@@ -98,6 +101,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
@@ -121,6 +125,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
torch.set_rng_state(rng_state)
@@ -142,6 +147,7 @@ def sample_image_inference(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
):
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
@@ -150,7 +156,7 @@ def sample_image_inference(
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
# controlnet_image = prompt_dict.get("controlnet_image")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
@@ -169,6 +175,9 @@ def sample_image_inference(
# if negative_prompt is None:
# negative_prompt = ""
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
@@ -224,7 +233,7 @@ def sample_image_inference(
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = x.float()
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
@@ -301,18 +310,37 @@ def denoise(
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
if controlnet is not None:
block_samples, block_single_samples = controlnet(
img=img,
img_ids=img_ids,
controlnet_cond=controlnet_img,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
else:
block_samples = None
block_single_samples = None
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,

View File

@@ -153,11 +153,14 @@ def load_ae(
return ae
def load_controlnet(name, device, transformer=None):
with torch.device(device):
def load_controlnet():
# TODO
is_schnell = False
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device("meta"):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
if transformer is not None:
controlnet.load_state_dict(transformer.state_dict(), strict=False)
# if transformer is not None:
# controlnet.load_state_dict(transformer.state_dict(), strict=False)
return controlnet