fix depth value

This commit is contained in:
minux302
2024-11-18 13:03:28 +00:00
parent 4dd4cd6ec8
commit 31ca899b6b

View File

@@ -1093,7 +1093,7 @@ class ControlNetFlux(nn.Module):
Transformer model for flow matching on sequences. Transformer model for flow matching on sequences.
""" """
def __init__(self, params: FluxParams, controlnet_depth=2): def __init__(self, params: FluxParams, controlnet_depth=2, controlnet_single_depth=0):
super().__init__() super().__init__()
self.params = params self.params = params
@@ -1128,7 +1128,7 @@ class ControlNetFlux(nn.Module):
self.single_blocks = nn.ModuleList( self.single_blocks = nn.ModuleList(
[ [
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(0) # TODO for _ in range(controlnet_single_depth)
] ]
) )
@@ -1148,7 +1148,7 @@ class ControlNetFlux(nn.Module):
controlnet_block = zero_module(controlnet_block) controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks.append(controlnet_block)
self.controlnet_blocks_for_single = nn.ModuleList([]) self.controlnet_blocks_for_single = nn.ModuleList([])
for _ in range(0): # TODO for _ in range(controlnet_single_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block) controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks_for_single.append(controlnet_block) self.controlnet_blocks_for_single.append(controlnet_block)