mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support BREAK in generation script
This commit is contained in:
@@ -457,7 +457,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
|
|||||||
|
|
||||||
upsampler.forward = make_replacer(upsampler)
|
upsampler.forward = make_replacer(upsampler)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_memory_efficient():
|
def replace_vae_attn_to_memory_efficient():
|
||||||
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
|
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
|
||||||
@@ -1795,6 +1795,9 @@ def parse_prompt_attention(text):
|
|||||||
for p in range(start_position, len(res)):
|
for p in range(start_position, len(res)):
|
||||||
res[p][1] *= multiplier
|
res[p][1] *= multiplier
|
||||||
|
|
||||||
|
# keep break as separate token
|
||||||
|
text = text.replace("BREAK", "\\BREAK\\")
|
||||||
|
|
||||||
for m in re_attention.finditer(text):
|
for m in re_attention.finditer(text):
|
||||||
text = m.group(0)
|
text = m.group(0)
|
||||||
weight = m.group(1)
|
weight = m.group(1)
|
||||||
@@ -1826,7 +1829,7 @@ def parse_prompt_attention(text):
|
|||||||
# merge runs of identical weights
|
# merge runs of identical weights
|
||||||
i = 0
|
i = 0
|
||||||
while i + 1 < len(res):
|
while i + 1 < len(res):
|
||||||
if res[i][1] == res[i + 1][1]:
|
if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK":
|
||||||
res[i][0] += res[i + 1][0]
|
res[i][0] += res[i + 1][0]
|
||||||
res.pop(i + 1)
|
res.pop(i + 1)
|
||||||
else:
|
else:
|
||||||
@@ -1843,11 +1846,25 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
|||||||
tokens = []
|
tokens = []
|
||||||
weights = []
|
weights = []
|
||||||
truncated = False
|
truncated = False
|
||||||
|
|
||||||
for text in prompt:
|
for text in prompt:
|
||||||
texts_and_weights = parse_prompt_attention(text)
|
texts_and_weights = parse_prompt_attention(text)
|
||||||
text_token = []
|
text_token = []
|
||||||
text_weight = []
|
text_weight = []
|
||||||
for word, weight in texts_and_weights:
|
for word, weight in texts_and_weights:
|
||||||
|
if word.strip() == "BREAK":
|
||||||
|
# pad until next multiple of tokenizer's max token length
|
||||||
|
pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length)
|
||||||
|
print(f"BREAK pad_len: {pad_len}")
|
||||||
|
for i in range(pad_len):
|
||||||
|
# v2のときEOSをつけるべきかどうかわからないぜ
|
||||||
|
# if i == 0:
|
||||||
|
# text_token.append(pipe.tokenizer.eos_token_id)
|
||||||
|
# else:
|
||||||
|
text_token.append(pipe.tokenizer.pad_token_id)
|
||||||
|
text_weight.append(1.0)
|
||||||
|
continue
|
||||||
|
|
||||||
# tokenize and discard the starting and the ending token
|
# tokenize and discard the starting and the ending token
|
||||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user