fix depth value

This commit is contained in:
minux302
2024-11-18 13:03:28 +00:00
parent 4dd4cd6ec8
commit 31ca899b6b

View File

@@ -1093,7 +1093,7 @@ class ControlNetFlux(nn.Module):
Transformer model for flow matching on sequences.
"""
def __init__(self, params: FluxParams, controlnet_depth=2):
def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0):
super().__init__()
self.params = params
@@ -1128,7 +1128,7 @@ class ControlNetFlux(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(0) # TODO
for _ in range(controlnet_single_depth)
]
)
@@ -1148,7 +1148,7 @@ class ControlNetFlux(nn.Module):
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_blocks_for_single = nn.ModuleList([])
for _ in range(0): # TODO
for _ in range(controlnet_single_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks_for_single.append(controlnet_block)