support SD3.5M

This commit is contained in:
Kohya S
2024-10-30 12:51:49 +09:00
parent 75554867ce
commit bdddc20d68
5 changed files with 99 additions and 58 deletions

View File

@@ -41,20 +41,21 @@ def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
x_block_self_attn_layers = []
re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight")
re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
for key in list(state_dict.keys()):
m = re_attn.match(key)
m = re_attn.search(key)
if m:
x_block_self_attn_layers.append(int(m.group(1)))
assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported"
context_embedder_in_features = context_shape[1]
context_embedder_out_features = context_shape[0]
# only supports 3-5-large and 3-medium
# only supports 3-5-large, medium or 3-medium
if qk_norm is not None:
model_type = "3-5-large"
if len(x_block_self_attn_layers) == 0:
model_type = "3-5-large"
else:
model_type = "3-5-medium"
else:
model_type = "3-medium"