mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix clip_skip not working in weighted captions and sample generation
This commit is contained in:
@@ -265,11 +265,6 @@ def get_unweighted_text_embeddings(
|
|||||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||||
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
||||||
|
|
||||||
# cover the head and the tail by the starting and the ending tokens
|
|
||||||
text_input_chunk[:, 0] = text_input[0, 0]
|
|
||||||
text_input_chunk[:, -1] = text_input[0, -1]
|
|
||||||
text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
|
|
||||||
|
|
||||||
if no_boseos_middle:
|
if no_boseos_middle:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
# discard the ending token
|
# discard the ending token
|
||||||
@@ -284,7 +279,12 @@ def get_unweighted_text_embeddings(
|
|||||||
text_embeddings.append(text_embedding)
|
text_embeddings.append(text_embedding)
|
||||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||||
else:
|
else:
|
||||||
|
if clip_skip is None or clip_skip == 1:
|
||||||
text_embeddings = text_encoder(text_input)[0]
|
text_embeddings = text_encoder(text_input)[0]
|
||||||
|
else:
|
||||||
|
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||||
|
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||||
|
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||||
return text_embeddings
|
return text_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -245,11 +245,6 @@ def get_unweighted_text_embeddings(
|
|||||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||||
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
||||||
|
|
||||||
# cover the head and the tail by the starting and the ending tokens
|
|
||||||
text_input_chunk[:, 0] = text_input[0, 0]
|
|
||||||
text_input_chunk[:, -1] = text_input[0, -1]
|
|
||||||
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
|
|
||||||
|
|
||||||
if no_boseos_middle:
|
if no_boseos_middle:
|
||||||
if i == 0:
|
if i == 0:
|
||||||
# discard the ending token
|
# discard the ending token
|
||||||
@@ -264,7 +259,12 @@ def get_unweighted_text_embeddings(
|
|||||||
text_embeddings.append(text_embedding)
|
text_embeddings.append(text_embedding)
|
||||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||||
else:
|
else:
|
||||||
|
if clip_skip is None or clip_skip == 1:
|
||||||
text_embeddings = pipe.text_encoder(text_input)[0]
|
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||||
|
else:
|
||||||
|
enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||||
|
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||||
|
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||||
return text_embeddings
|
return text_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user