mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support weighted captions for sdxl LoRA and fine tuning
This commit is contained in:
@@ -74,6 +74,9 @@ class TokenizeStrategy:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _get_weighted_input_ids(
|
def _get_weighted_input_ids(
|
||||||
|
|||||||
@@ -174,7 +174,8 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
tokenize_strategy: TokenizeStrategy
|
tokenize_strategy: TokenizeStrategy
|
||||||
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]
|
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
|
||||||
|
If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required
|
||||||
tokens: List of tokens, for text_encoder1 and text_encoder2
|
tokens: List of tokens, for text_encoder1 and text_encoder2
|
||||||
"""
|
"""
|
||||||
if len(models) == 2:
|
if len(models) == 2:
|
||||||
|
|||||||
@@ -104,8 +104,8 @@ def train(args):
|
|||||||
setup_logging(args, reset=True)
|
setup_logging(args, reset=True)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
not args.weighted_captions
|
not args.weighted_captions or not args.cache_text_encoder_outputs
|
||||||
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません"
|
||||||
assert (
|
assert (
|
||||||
not args.train_text_encoder or not args.cache_text_encoder_outputs
|
not args.train_text_encoder or not args.cache_text_encoder_outputs
|
||||||
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
|
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
|
||||||
@@ -660,22 +660,24 @@ def train(args):
|
|||||||
input_ids1, input_ids2 = batch["input_ids_list"]
|
input_ids1, input_ids2 = batch["input_ids_list"]
|
||||||
with torch.set_grad_enabled(args.train_text_encoder):
|
with torch.set_grad_enabled(args.train_text_encoder):
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
# TODO support weighted captions
|
if args.weighted_captions:
|
||||||
# if args.weighted_captions:
|
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||||
# encoder_hidden_states = get_weighted_text_embeddings(
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = (
|
||||||
# tokenizer,
|
text_encoding_strategy.encode_tokens_with_weights(
|
||||||
# text_encoder,
|
tokenize_strategy,
|
||||||
# batch["captions"],
|
[text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
|
||||||
# accelerator.device,
|
input_ids_list,
|
||||||
# args.max_token_length // 75 if args.max_token_length else 1,
|
weights_list,
|
||||||
# clip_skip=args.clip_skip,
|
)
|
||||||
# )
|
)
|
||||||
# else:
|
else:
|
||||||
input_ids1 = input_ids1.to(accelerator.device)
|
input_ids1 = input_ids1.to(accelerator.device)
|
||||||
input_ids2 = input_ids2.to(accelerator.device)
|
input_ids2 = input_ids2.to(accelerator.device)
|
||||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
|
tokenize_strategy,
|
||||||
)
|
[text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
|
||||||
|
[input_ids1, input_ids2],
|
||||||
|
)
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
|
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
|
||||||
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
|
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
|
||||||
|
|||||||
@@ -12,24 +12,21 @@ from library.device_utils import init_ipex, clean_memory_on_device
|
|||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from diffusers import DDPMScheduler, ControlNetModel
|
from diffusers import DDPMScheduler
|
||||||
from diffusers.utils.torch_utils import is_compiled_module
|
from diffusers.utils.torch_utils import is_compiled_module
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from library import (
|
from library import (
|
||||||
deepspeed_utils,
|
deepspeed_utils,
|
||||||
sai_model_spec,
|
sai_model_spec,
|
||||||
sdxl_model_util,
|
sdxl_model_util,
|
||||||
sdxl_original_unet,
|
|
||||||
sdxl_train_util,
|
sdxl_train_util,
|
||||||
strategy_base,
|
strategy_base,
|
||||||
strategy_sd,
|
strategy_sd,
|
||||||
strategy_sdxl,
|
strategy_sdxl,
|
||||||
)
|
)
|
||||||
|
|
||||||
import library.model_util as model_util
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
import library.config_util as config_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
|
|||||||
@@ -1123,14 +1123,21 @@ class NetworkTrainer:
|
|||||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
if args.weighted_captions:
|
if args.weighted_captions:
|
||||||
# SD only
|
# # SD only
|
||||||
encoded_text_encoder_conds = get_weighted_text_embeddings(
|
# encoded_text_encoder_conds = get_weighted_text_embeddings(
|
||||||
tokenizers[0],
|
# tokenizers[0],
|
||||||
text_encoder,
|
# text_encoder,
|
||||||
batch["captions"],
|
# batch["captions"],
|
||||||
accelerator.device,
|
# accelerator.device,
|
||||||
args.max_token_length // 75 if args.max_token_length else 1,
|
# args.max_token_length // 75 if args.max_token_length else 1,
|
||||||
clip_skip=args.clip_skip,
|
# clip_skip=args.clip_skip,
|
||||||
|
# )
|
||||||
|
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||||
|
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||||
|
tokenize_strategy,
|
||||||
|
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||||
|
input_ids_list,
|
||||||
|
weights_list,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||||
@@ -1139,8 +1146,8 @@ class NetworkTrainer:
|
|||||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||||
input_ids,
|
input_ids,
|
||||||
)
|
)
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||||
|
|
||||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||||
if len(text_encoder_conds) == 0:
|
if len(text_encoder_conds) == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user