Merge branch 'main' into original-u-net

This commit is contained in:
Kohya S
2023-06-07 07:42:27 +09:00
4 changed files with 30 additions and 6 deletions

View File

@@ -140,6 +140,13 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History ## Change History
### 6 Jun. 2023, 2023/06/06
- Fixed `train_network.py` so that it will most likely work with older versions of LyCORIS.
- `gen_img_diffusers.py` now supports `BREAK` syntax.
- `train_network.py`がLyCORISの以前のバージョンでも恐らく動作するよう修正しました。
- `gen_img_diffusers.py` で `BREAK` 構文をサポートしました。
### 3 Jun. 2023, 2023/06/03 ### 3 Jun. 2023, 2023/06/03
- Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova! - Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!

View File

@@ -463,7 +463,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
upsampler.forward = make_replacer(upsampler) upsampler.forward = make_replacer(upsampler)
""" """
def replace_vae_attn_to_memory_efficient(): def replace_vae_attn_to_memory_efficient():
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
@@ -1801,6 +1801,9 @@ def parse_prompt_attention(text):
for p in range(start_position, len(res)): for p in range(start_position, len(res)):
res[p][1] *= multiplier res[p][1] *= multiplier
# keep break as separate token
text = text.replace("BREAK", "\\BREAK\\")
for m in re_attention.finditer(text): for m in re_attention.finditer(text):
text = m.group(0) text = m.group(0)
weight = m.group(1) weight = m.group(1)
@@ -1832,7 +1835,7 @@ def parse_prompt_attention(text):
# merge runs of identical weights # merge runs of identical weights
i = 0 i = 0
while i + 1 < len(res): while i + 1 < len(res):
if res[i][1] == res[i + 1][1]: if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK":
res[i][0] += res[i + 1][0] res[i][0] += res[i + 1][0]
res.pop(i + 1) res.pop(i + 1)
else: else:
@@ -1849,11 +1852,25 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
tokens = [] tokens = []
weights = [] weights = []
truncated = False truncated = False
for text in prompt: for text in prompt:
texts_and_weights = parse_prompt_attention(text) texts_and_weights = parse_prompt_attention(text)
text_token = [] text_token = []
text_weight = [] text_weight = []
for word, weight in texts_and_weights: for word, weight in texts_and_weights:
if word.strip() == "BREAK":
# pad until next multiple of tokenizer's max token length
pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length)
print(f"BREAK pad_len: {pad_len}")
for i in range(pad_len):
# v2のときEOSをつけるべきかどうかわからないぜ
# if i == 0:
# text_token.append(pipe.tokenizer.eos_token_id)
# else:
text_token.append(pipe.tokenizer.pad_token_id)
text_weight.append(1.0)
continue
# tokenize and discard the starting and the ending token # tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1] token = pipe.tokenizer(word).input_ids[1:-1]

View File

@@ -400,7 +400,7 @@ def parse_block_lr_kwargs(nw_kwargs):
return down_lr_weight, mid_lr_weight, up_lr_weight return down_lr_weight, mid_lr_weight, up_lr_weight
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs):
if network_dim is None: if network_dim is None:
network_dim = 4 # default network_dim = 4 # default
if network_alpha is None: if network_alpha is None:
@@ -455,7 +455,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
multiplier=multiplier, multiplier=multiplier,
lora_dim=network_dim, lora_dim=network_dim,
alpha=network_alpha, alpha=network_alpha,
dropout=dropout, dropout=neuron_dropout,
rank_dropout=rank_dropout, rank_dropout=rank_dropout,
module_dropout=module_dropout, module_dropout=module_dropout,
conv_lora_dim=conv_dim, conv_lora_dim=conv_dim,

View File

@@ -212,7 +212,7 @@ def train(args):
else: else:
# LyCORIS will work with this... # LyCORIS will work with this...
network = network_module.create_network( network = network_module.create_network(
1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, dropout=args.network_dropout, **net_kwargs 1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, neuron_dropout=args.network_dropout, **net_kwargs
) )
if network is None: if network is None:
return return
@@ -724,7 +724,7 @@ def train(args):
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
if args.scale_weight_norms: if args.scale_weight_norms:
progress_bar.set_postfix(**max_mean_logs) progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.logging_dir is not None: if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)