fix for adding controlnet

This commit is contained in:
minux302
2024-11-15 23:48:51 +09:00
parent ccfaa001e7
commit 42f6edf3a8
4 changed files with 855 additions and 531 deletions

View File

@@ -153,11 +153,14 @@ def load_ae(
return ae
def load_controlnet(name, device, transformer=None):
with torch.device(device):
def load_controlnet():
# TODO
is_schnell = False
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device("meta"):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
if transformer is not None:
controlnet.load_state_dict(transformer.state_dict(), strict=False)
# if transformer is not None:
# controlnet.load_state_dict(transformer.state_dict(), strict=False)
return controlnet