diff --git a/README.md b/README.md
index 3b320acc..cb3803f0 100644
--- a/README.md
+++ b/README.md
@@ -140,6 +140,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ## Change History
 
+### 6 Jun. 2023, 2023/06/06
+
+- Fix `train_network.py` to probably work with older versions of LyCORIS.
+- `gen_img_diffusers.py` now supports `BREAK` syntax.
+- `train_network.py`がLyCORISの以前のバージョンでも恐らく動作するよう修正しました。
+- `gen_img_diffusers.py` で `BREAK` 構文をサポートしました。
+
 ### 3 Jun. 2023, 2023/06/03
 
 - Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 33b7a65c..28f7323a 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -463,7 +463,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
         upsampler.forward = make_replacer(upsampler)
 """
 
-    
+
 def replace_vae_attn_to_memory_efficient():
     print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
     flash_func = FlashAttentionFunction
@@ -1801,6 +1801,9 @@ def parse_prompt_attention(text):
         for p in range(start_position, len(res)):
             res[p][1] *= multiplier
 
+    # keep break as separate token
+    text = text.replace("BREAK", "\\BREAK\\")
+
     for m in re_attention.finditer(text):
         text = m.group(0)
         weight = m.group(1)
@@ -1832,7 +1835,7 @@
     # merge runs of identical weights
     i = 0
     while i + 1 < len(res):
-        if res[i][1] == res[i + 1][1]:
+        if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK":
             res[i][0] += res[i + 1][0]
             res.pop(i + 1)
         else:
@@ -1849,11 +1852,25 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
     tokens = []
     weights = []
     truncated = False
+
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
         text_weight = []
         for word, weight in texts_and_weights:
+            if word.strip() == "BREAK":
+                # pad until next multiple of tokenizer's max token length
+                pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length)
+                print(f"BREAK pad_len: {pad_len}")
+                for i in range(pad_len):
+                    # v2のときEOSをつけるべきかどうかわからないぜ
+                    # if i == 0:
+                    #     text_token.append(pipe.tokenizer.eos_token_id)
+                    # else:
+                    text_token.append(pipe.tokenizer.pad_token_id)
+                    text_weight.append(1.0)
+                continue
+
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word).input_ids[1:-1]
             text_token += token
diff --git a/networks/lora.py b/networks/lora.py
index 9f2f5094..27f59344 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -400,7 +400,7 @@ def parse_block_lr_kwargs(nw_kwargs):
     return down_lr_weight, mid_lr_weight, up_lr_weight
 
 
-def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs):
+def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs):
     if network_dim is None:
         network_dim = 4  # default
     if network_alpha is None:
@@ -455,7 +455,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
         multiplier=multiplier,
         lora_dim=network_dim,
         alpha=network_alpha,
-        dropout=dropout,
+        dropout=neuron_dropout,
         rank_dropout=rank_dropout,
         module_dropout=module_dropout,
         conv_lora_dim=conv_dim,
diff --git a/train_network.py b/train_network.py
index 051d0d18..b62aef7e 100644
--- a/train_network.py
+++ b/train_network.py
@@ -212,7 +212,7 @@ def train(args):
     else:
         # LyCORIS will work with this...
         network = network_module.create_network(
-            1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, dropout=args.network_dropout, **net_kwargs
+            1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, neuron_dropout=args.network_dropout, **net_kwargs
         )
     if network is None:
         return
@@ -724,7 +724,7 @@ def train(args):
             progress_bar.set_postfix(**logs)
 
             if args.scale_weight_norms:
-                progress_bar.set_postfix(**max_mean_logs)
+                progress_bar.set_postfix(**{**max_mean_logs, **logs})
 
             if args.logging_dir is not None:
                 logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)