mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge c692dbe14c into 1dae34b0af
This commit is contained in:
@@ -435,10 +435,13 @@ class TextualInversionTrainer:
|
|||||||
text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
|
text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
|
||||||
|
|
||||||
index_no_updates_list = []
|
index_no_updates_list = []
|
||||||
|
index_updates_list = []
|
||||||
orig_embeds_params_list = []
|
orig_embeds_params_list = []
|
||||||
for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders):
|
for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders):
|
||||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||||
index_no_updates_list.append(index_no_updates)
|
index_no_updates_list.append(index_no_updates)
|
||||||
|
index_updates = ~index_no_updates
|
||||||
|
index_updates_list.append(index_updates)
|
||||||
|
|
||||||
# accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
|
# accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
@@ -634,8 +637,31 @@ class TextualInversionTrainer:
|
|||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
# normalize embeddings
|
||||||
|
if args.clip_ti_decay:
|
||||||
|
for text_encoder, index_updates in zip(
|
||||||
|
text_encoders, index_updates_list
|
||||||
|
):
|
||||||
|
pre_norm = (
|
||||||
|
text_encoder.get_input_embeddings()
|
||||||
|
.weight[index_updates, :]
|
||||||
|
.norm(dim=-1, keepdim=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
|
||||||
|
text_encoder.get_input_embeddings().weight[
|
||||||
|
index_updates
|
||||||
|
] = torch.nn.functional.normalize(
|
||||||
|
text_encoder.get_input_embeddings().weight[
|
||||||
|
index_updates, :
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
) * (
|
||||||
|
pre_norm + lambda_ * (args.clip_ti_decay - pre_norm)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
for text_encoder, orig_embeds_params, index_no_updates in zip(
|
for text_encoder, orig_embeds_params, index_no_updates in zip(
|
||||||
text_encoders, orig_embeds_params_list, index_no_updates_list
|
text_encoders, orig_embeds_params_list, index_no_updates_list
|
||||||
):
|
):
|
||||||
@@ -818,6 +844,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--clip_ti_decay",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Keep the norm of the textual inversion intact (0.4 is a good starting point)",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user