mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
fix for adding controlnet
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -125,9 +125,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||||
|
|
||||||
controlnet = flux_utils.load_controlnet()
|
|
||||||
controlnet.train()
|
|
||||||
|
|
||||||
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet
|
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model, controlnet
|
||||||
|
|
||||||
def get_tokenize_strategy(self, args):
|
def get_tokenize_strategy(self, args):
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ def sample_images(
|
|||||||
text_encoders,
|
text_encoders,
|
||||||
sample_prompts_te_outputs,
|
sample_prompts_te_outputs,
|
||||||
prompt_replacement=None,
|
prompt_replacement=None,
|
||||||
|
controlnet=None
|
||||||
):
|
):
|
||||||
if steps == 0:
|
if steps == 0:
|
||||||
if not args.sample_at_first:
|
if not args.sample_at_first:
|
||||||
@@ -67,6 +68,8 @@ def sample_images(
|
|||||||
flux = accelerator.unwrap_model(flux)
|
flux = accelerator.unwrap_model(flux)
|
||||||
if text_encoders is not None:
|
if text_encoders is not None:
|
||||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||||
|
if controlnet is not None:
|
||||||
|
controlnet = accelerator.unwrap_model(controlnet)
|
||||||
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
||||||
|
|
||||||
prompts = train_util.load_prompts(args.sample_prompts)
|
prompts = train_util.load_prompts(args.sample_prompts)
|
||||||
@@ -98,6 +101,7 @@ def sample_images(
|
|||||||
steps,
|
steps,
|
||||||
sample_prompts_te_outputs,
|
sample_prompts_te_outputs,
|
||||||
prompt_replacement,
|
prompt_replacement,
|
||||||
|
controlnet
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||||
@@ -121,6 +125,7 @@ def sample_images(
|
|||||||
steps,
|
steps,
|
||||||
sample_prompts_te_outputs,
|
sample_prompts_te_outputs,
|
||||||
prompt_replacement,
|
prompt_replacement,
|
||||||
|
controlnet
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.set_rng_state(rng_state)
|
torch.set_rng_state(rng_state)
|
||||||
@@ -142,6 +147,7 @@ def sample_image_inference(
|
|||||||
steps,
|
steps,
|
||||||
sample_prompts_te_outputs,
|
sample_prompts_te_outputs,
|
||||||
prompt_replacement,
|
prompt_replacement,
|
||||||
|
controlnet
|
||||||
):
|
):
|
||||||
assert isinstance(prompt_dict, dict)
|
assert isinstance(prompt_dict, dict)
|
||||||
# negative_prompt = prompt_dict.get("negative_prompt")
|
# negative_prompt = prompt_dict.get("negative_prompt")
|
||||||
@@ -150,7 +156,7 @@ def sample_image_inference(
|
|||||||
height = prompt_dict.get("height", 512)
|
height = prompt_dict.get("height", 512)
|
||||||
scale = prompt_dict.get("scale", 3.5)
|
scale = prompt_dict.get("scale", 3.5)
|
||||||
seed = prompt_dict.get("seed")
|
seed = prompt_dict.get("seed")
|
||||||
# controlnet_image = prompt_dict.get("controlnet_image")
|
controlnet_image = prompt_dict.get("controlnet_image")
|
||||||
prompt: str = prompt_dict.get("prompt", "")
|
prompt: str = prompt_dict.get("prompt", "")
|
||||||
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||||
|
|
||||||
@@ -169,6 +175,9 @@ def sample_image_inference(
|
|||||||
|
|
||||||
# if negative_prompt is None:
|
# if negative_prompt is None:
|
||||||
# negative_prompt = ""
|
# negative_prompt = ""
|
||||||
|
if controlnet_image is not None:
|
||||||
|
controlnet_image = Image.open(controlnet_image).convert("RGB")
|
||||||
|
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
|
||||||
|
|
||||||
height = max(64, height - height % 16) # round to divisible by 16
|
height = max(64, height - height % 16) # round to divisible by 16
|
||||||
width = max(64, width - width % 16) # round to divisible by 16
|
width = max(64, width - width % 16) # round to divisible by 16
|
||||||
@@ -224,7 +233,7 @@ def sample_image_inference(
|
|||||||
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
||||||
|
|
||||||
with accelerator.autocast(), torch.no_grad():
|
with accelerator.autocast(), torch.no_grad():
|
||||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
|
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||||
|
|
||||||
x = x.float()
|
x = x.float()
|
||||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||||
@@ -301,18 +310,37 @@ def denoise(
|
|||||||
timesteps: list[float],
|
timesteps: list[float],
|
||||||
guidance: float = 4.0,
|
guidance: float = 4.0,
|
||||||
t5_attn_mask: Optional[torch.Tensor] = None,
|
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
controlnet: Optional[flux_models.ControlNetFlux] = None,
|
||||||
|
controlnet_img: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
# this is ignored for schnell
|
# this is ignored for schnell
|
||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
model.prepare_block_swap_before_forward()
|
model.prepare_block_swap_before_forward()
|
||||||
|
if controlnet is not None:
|
||||||
|
block_samples, block_single_samples = controlnet(
|
||||||
|
img=img,
|
||||||
|
img_ids=img_ids,
|
||||||
|
controlnet_cond=controlnet_img,
|
||||||
|
txt=txt,
|
||||||
|
txt_ids=txt_ids,
|
||||||
|
y=vec,
|
||||||
|
timesteps=t_vec,
|
||||||
|
guidance=guidance_vec,
|
||||||
|
txt_attention_mask=t5_attn_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
block_samples = None
|
||||||
|
block_single_samples = None
|
||||||
pred = model(
|
pred = model(
|
||||||
img=img,
|
img=img,
|
||||||
img_ids=img_ids,
|
img_ids=img_ids,
|
||||||
txt=txt,
|
txt=txt,
|
||||||
txt_ids=txt_ids,
|
txt_ids=txt_ids,
|
||||||
y=vec,
|
y=vec,
|
||||||
|
block_controlnet_hidden_states=block_samples,
|
||||||
|
block_controlnet_single_hidden_states=block_single_samples,
|
||||||
timesteps=t_vec,
|
timesteps=t_vec,
|
||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
txt_attention_mask=t5_attn_mask,
|
txt_attention_mask=t5_attn_mask,
|
||||||
|
|||||||
@@ -153,11 +153,14 @@ def load_ae(
|
|||||||
return ae
|
return ae
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet(name, device, transformer=None):
|
def load_controlnet():
|
||||||
with torch.device(device):
|
# TODO
|
||||||
|
is_schnell = False
|
||||||
|
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||||
|
with torch.device("meta"):
|
||||||
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
|
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
|
||||||
if transformer is not None:
|
# if transformer is not None:
|
||||||
controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
# controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
||||||
return controlnet
|
return controlnet
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user