diff --git a/library/flux_models.py b/library/flux_models.py index 4123b40e..328ad481 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1093,7 +1093,7 @@ class ControlNetFlux(nn.Module): Transformer model for flow matching on sequences. """ - def __init__(self, params: FluxParams, controlnet_depth=2): + def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0): super().__init__() self.params = params @@ -1128,7 +1128,7 @@ class ControlNetFlux(nn.Module): self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(0) # TODO + for _ in range(controlnet_single_depth) ] ) @@ -1148,7 +1148,7 @@ class ControlNetFlux(nn.Module): controlnet_block = zero_module(controlnet_block) self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) - for _ in range(0): # TODO + for _ in range(controlnet_single_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block)