From 24823b061df14e0d5a947ba453997aaaa7b5a903 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 6 Jun 2023 21:53:58 +0900 Subject: [PATCH] support BREAK in generation script --- gen_img_diffusers.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 27bd7460..33c40441 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -457,7 +457,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform upsampler.forward = make_replacer(upsampler) """ - + def replace_vae_attn_to_memory_efficient(): print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") @@ -1795,6 +1795,9 @@ def parse_prompt_attention(text): for p in range(start_position, len(res)): res[p][1] *= multiplier + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + for m in re_attention.finditer(text): text = m.group(0) weight = m.group(1) @@ -1826,7 +1829,7 @@ def parse_prompt_attention(text): # merge runs of identical weights i = 0 while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": res[i][0] += res[i + 1][0] res.pop(i + 1) else: @@ -1843,11 +1846,25 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: tokens = [] weights = [] truncated = False + for text in prompt: texts_and_weights = parse_prompt_attention(text) text_token = [] text_weight = [] for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length) + print(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(pipe.tokenizer.eos_token_id) + # else: + text_token.append(pipe.tokenizer.pad_token_id) + text_weight.append(1.0) + continue + # tokenize and discard the starting and the ending token token = pipe.tokenizer(word).input_ids[1:-1]