mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
sample generation in SDXL ControlNet training
This commit is contained in:
@@ -13,12 +13,20 @@ from tqdm import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
from diffusers.models import AutoencoderKL
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.utils import logging
|
||||
from PIL import Image
|
||||
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
from library import (
|
||||
sdxl_model_util,
|
||||
sdxl_train_util,
|
||||
strategy_base,
|
||||
strategy_sdxl,
|
||||
train_util,
|
||||
sdxl_original_unet,
|
||||
sdxl_original_control_net,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
@@ -537,7 +545,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: List[CLIPTextModel],
|
||||
tokenizer: List[CLIPTokenizer],
|
||||
unet: UNet2DConditionModel,
|
||||
unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet],
|
||||
scheduler: SchedulerMixin,
|
||||
# clip_skip: int,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
@@ -594,74 +602,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
max_embeddings_multiples,
|
||||
is_sdxl_text_encoder2,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `list(int)`):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
||||
"""
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
if negative_prompt is None:
|
||||
negative_prompt = [""] * batch_size
|
||||
elif isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt] * batch_size
|
||||
if batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
|
||||
pipe=self,
|
||||
prompt=prompt,
|
||||
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
||||
max_embeddings_multiples=max_embeddings_multiples,
|
||||
clip_skip=self.clip_skip,
|
||||
is_sdxl_text_encoder2=is_sdxl_text_encoder2,
|
||||
)
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ??
|
||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
if text_pool is not None:
|
||||
text_pool = text_pool.repeat(1, num_images_per_prompt)
|
||||
text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
bs_embed, seq_len, _ = uncond_embeddings.shape
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
if uncond_pool is not None:
|
||||
uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
|
||||
uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
|
||||
|
||||
return text_embeddings, text_pool, uncond_embeddings, uncond_pool
|
||||
|
||||
return text_embeddings, text_pool, None, None
|
||||
|
||||
def check_inputs(self, prompt, height, width, strength, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
@@ -792,7 +732,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
max_embeddings_multiples: Optional[int] = 3,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
controlnet=None,
|
||||
controlnet: sdxl_original_control_net.SdxlControlNet = None,
|
||||
controlnet_image=None,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
@@ -896,32 +836,24 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
# 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す
|
||||
# To simplify the implementation, switch the tokenzer/text encoder and call it twice
|
||||
text_embeddings_list = []
|
||||
text_pool = None
|
||||
uncond_embeddings_list = []
|
||||
uncond_pool = None
|
||||
for i in range(len(self.tokenizers)):
|
||||
self.tokenizer = self.tokenizers[i]
|
||||
self.text_encoder = self.text_encoders[i]
|
||||
tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
max_embeddings_multiples,
|
||||
is_sdxl_text_encoder2=i == 1,
|
||||
text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt)
|
||||
hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy, self.text_encoders, text_input_ids, text_weights
|
||||
)
|
||||
text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "")
|
||||
hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy, self.text_encoders, input_ids, weights
|
||||
)
|
||||
text_embeddings_list.append(text_embeddings)
|
||||
uncond_embeddings_list.append(uncond_embeddings)
|
||||
|
||||
if tp1 is not None:
|
||||
text_pool = tp1
|
||||
if up1 is not None:
|
||||
uncond_pool = up1
|
||||
uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
|
||||
else:
|
||||
uncond_embeddings = None
|
||||
uncond_pool = None
|
||||
|
||||
unet_dtype = self.unet.dtype
|
||||
dtype = unet_dtype
|
||||
@@ -970,23 +902,23 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# create size embs and concat embeddings for SDXL
|
||||
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
|
||||
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype)
|
||||
crop_size = torch.zeros_like(orig_size)
|
||||
target_size = orig_size
|
||||
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
|
||||
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype)
|
||||
|
||||
# make conditionings
|
||||
text_pool = text_pool.to(device, dtype)
|
||||
if do_classifier_free_guidance:
|
||||
text_embeddings = torch.cat(text_embeddings_list, dim=2)
|
||||
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
|
||||
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
|
||||
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype)
|
||||
|
||||
cond_vector = torch.cat([text_pool, embs], dim=1)
|
||||
uncond_vector = torch.cat([uncond_pool, embs], dim=1)
|
||||
vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
|
||||
uncond_pool = uncond_pool.to(device, dtype)
|
||||
cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype)
|
||||
uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype)
|
||||
vector_embedding = torch.cat([uncond_vector, cond_vector])
|
||||
else:
|
||||
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
|
||||
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
|
||||
text_embedding = text_embeddings.to(device, dtype)
|
||||
vector_embedding = torch.cat([text_pool, embs], dim=1)
|
||||
|
||||
# 8. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
@@ -994,22 +926,14 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
unet_additional_args = {}
|
||||
if controlnet is not None:
|
||||
down_block_res_samples, mid_block_res_sample = controlnet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
controlnet_cond=controlnet_image,
|
||||
conditioning_scale=1.0,
|
||||
guess_mode=False,
|
||||
return_dict=False,
|
||||
)
|
||||
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
|
||||
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
|
||||
# FIXME SD1 ControlNet is not working
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
|
||||
if controlnet is not None:
|
||||
input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image)
|
||||
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add)
|
||||
else:
|
||||
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
|
||||
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
|
||||
|
||||
# perform guidance
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# base class for platform strategies. this file defines the interface for strategies
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -22,6 +23,24 @@ logger = logging.getLogger(__name__)
|
||||
class TokenizeStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
_re_attention = re.compile(
|
||||
r"""\\\(|
|
||||
\\\)|
|
||||
\\\[|
|
||||
\\]|
|
||||
\\\\|
|
||||
\\|
|
||||
\(|
|
||||
\[|
|
||||
:([+-]?[.\d]+)\)|
|
||||
\)|
|
||||
]|
|
||||
[^\\()\[\]:]+|
|
||||
:
|
||||
""",
|
||||
re.X,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
@@ -54,7 +73,151 @@ class TokenizeStrategy:
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor:
|
||||
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_weighted_input_ids(
|
||||
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
max_length includes starting and ending tokens.
|
||||
"""
|
||||
|
||||
def parse_prompt_attention(text):
|
||||
"""
|
||||
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||
Accepted tokens are:
|
||||
(abc) - increases attention to abc by a multiplier of 1.1
|
||||
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||
[abc] - decreases attention to abc by a multiplier of 1.1
|
||||
\( - literal character '('
|
||||
\[ - literal character '['
|
||||
\) - literal character ')'
|
||||
\] - literal character ']'
|
||||
\\ - literal character '\'
|
||||
anything else - just text
|
||||
>>> parse_prompt_attention('normal text')
|
||||
[['normal text', 1.0]]
|
||||
>>> parse_prompt_attention('an (important) word')
|
||||
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
||||
>>> parse_prompt_attention('(unbalanced')
|
||||
[['unbalanced', 1.1]]
|
||||
>>> parse_prompt_attention('\(literal\]')
|
||||
[['(literal]', 1.0]]
|
||||
>>> parse_prompt_attention('(unnecessary)(parens)')
|
||||
[['unnecessaryparens', 1.1]]
|
||||
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
||||
[['a ', 1.0],
|
||||
['house', 1.5730000000000004],
|
||||
[' ', 1.1],
|
||||
['on', 1.0],
|
||||
[' a ', 1.1],
|
||||
['hill', 0.55],
|
||||
[', sun, ', 1.1],
|
||||
['sky', 1.4641000000000006],
|
||||
['.', 1.1]]
|
||||
"""
|
||||
|
||||
res = []
|
||||
round_brackets = []
|
||||
square_brackets = []
|
||||
|
||||
round_bracket_multiplier = 1.1
|
||||
square_bracket_multiplier = 1 / 1.1
|
||||
|
||||
def multiply_range(start_position, multiplier):
|
||||
for p in range(start_position, len(res)):
|
||||
res[p][1] *= multiplier
|
||||
|
||||
for m in TokenizeStrategy._re_attention.finditer(text):
|
||||
text = m.group(0)
|
||||
weight = m.group(1)
|
||||
|
||||
if text.startswith("\\"):
|
||||
res.append([text[1:], 1.0])
|
||||
elif text == "(":
|
||||
round_brackets.append(len(res))
|
||||
elif text == "[":
|
||||
square_brackets.append(len(res))
|
||||
elif weight is not None and len(round_brackets) > 0:
|
||||
multiply_range(round_brackets.pop(), float(weight))
|
||||
elif text == ")" and len(round_brackets) > 0:
|
||||
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||
elif text == "]" and len(square_brackets) > 0:
|
||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||
else:
|
||||
res.append([text, 1.0])
|
||||
|
||||
for pos in round_brackets:
|
||||
multiply_range(pos, round_bracket_multiplier)
|
||||
|
||||
for pos in square_brackets:
|
||||
multiply_range(pos, square_bracket_multiplier)
|
||||
|
||||
if len(res) == 0:
|
||||
res = [["", 1.0]]
|
||||
|
||||
# merge runs of identical weights
|
||||
i = 0
|
||||
while i + 1 < len(res):
|
||||
if res[i][1] == res[i + 1][1]:
|
||||
res[i][0] += res[i + 1][0]
|
||||
res.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return res
|
||||
|
||||
def get_prompts_with_weights(text: str, max_length: int):
|
||||
r"""
|
||||
Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token.
|
||||
|
||||
No padding, starting or ending token is included.
|
||||
"""
|
||||
truncated = False
|
||||
|
||||
texts_and_weights = parse_prompt_attention(text)
|
||||
tokens = []
|
||||
weights = []
|
||||
for word, weight in texts_and_weights:
|
||||
# tokenize and discard the starting and the ending token
|
||||
token = tokenizer(word).input_ids[1:-1]
|
||||
tokens += token
|
||||
# copy the weight by length of token
|
||||
weights += [weight] * len(token)
|
||||
# stop if the text is too long (longer than truncation limit)
|
||||
if len(tokens) > max_length:
|
||||
truncated = True
|
||||
break
|
||||
# truncate
|
||||
if len(tokens) > max_length:
|
||||
truncated = True
|
||||
tokens = tokens[:max_length]
|
||||
weights = weights[:max_length]
|
||||
if truncated:
|
||||
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||
return tokens, weights
|
||||
|
||||
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
|
||||
r"""
|
||||
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
||||
"""
|
||||
tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
|
||||
weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
|
||||
return tokens, weights
|
||||
|
||||
if max_length is None:
|
||||
max_length = tokenizer.model_max_length
|
||||
|
||||
tokens, weights = get_prompts_with_weights(text, max_length - 2)
|
||||
tokens, weights = pad_tokens_and_weights(
|
||||
tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
|
||||
)
|
||||
return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
|
||||
|
||||
def _get_input_ids(
|
||||
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
for SD1.5/2.0/SDXL
|
||||
TODO support batch input
|
||||
@@ -62,7 +225,10 @@ class TokenizeStrategy:
|
||||
if max_length is None:
|
||||
max_length = tokenizer.model_max_length - 2
|
||||
|
||||
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
|
||||
if weighted:
|
||||
input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
|
||||
else:
|
||||
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
|
||||
|
||||
if max_length > tokenizer.model_max_length:
|
||||
input_ids = input_ids.squeeze(0)
|
||||
@@ -101,6 +267,17 @@ class TokenizeStrategy:
|
||||
iids_list.append(ids_chunk)
|
||||
|
||||
input_ids = torch.stack(iids_list) # 3,77
|
||||
|
||||
if weighted:
|
||||
weights = weights.squeeze(0)
|
||||
new_weights = torch.ones(input_ids.shape)
|
||||
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
||||
b = i // (tokenizer.model_max_length - 2)
|
||||
new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
|
||||
weights = new_weights
|
||||
|
||||
if weighted:
|
||||
return input_ids, weights
|
||||
return input_ids
|
||||
|
||||
|
||||
@@ -126,6 +303,17 @@ class TextEncodingStrategy:
|
||||
:return: list of output embeddings for each architecture
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_tokens_with_weights(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Encode tokens into embeddings and outputs.
|
||||
:param tokens: list of token tensors for each TextModel
|
||||
:param weights: list of weight tensors for each TextModel
|
||||
:return: list of output embeddings for each architecture
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TextEncoderOutputsCachingStrategy:
|
||||
|
||||
@@ -37,6 +37,22 @@ class SdxlTokenizeStrategy(TokenizeStrategy):
|
||||
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
|
||||
)
|
||||
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
tokens1_list, tokens2_list = [], []
|
||||
weights1_list, weights2_list = [], []
|
||||
for t in text:
|
||||
tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length)
|
||||
tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length)
|
||||
tokens1_list.append(tokens1)
|
||||
tokens2_list.append(tokens2)
|
||||
weights1_list.append(weights1)
|
||||
weights2_list.append(weights2)
|
||||
return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), (
|
||||
torch.stack(weights1_list, dim=0),
|
||||
torch.stack(weights2_list, dim=0),
|
||||
)
|
||||
|
||||
|
||||
class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self) -> None:
|
||||
@@ -98,7 +114,10 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
):
|
||||
# input_ids: b,n,77 -> b*n, 77
|
||||
b_size = input_ids1.size()[0]
|
||||
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
|
||||
if input_ids1.size()[1] == 1:
|
||||
max_token_length = None
|
||||
else:
|
||||
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
|
||||
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
|
||||
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
|
||||
input_ids1 = input_ids1.to(text_encoder1.device)
|
||||
@@ -172,6 +191,24 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
)
|
||||
return [hidden_states1, hidden_states2, pool2]
|
||||
|
||||
def encode_tokens_with_weights(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens)
|
||||
|
||||
# apply weights
|
||||
if weights[0].shape[1] == 1: # no max_token_length
|
||||
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
|
||||
hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2)
|
||||
hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2)
|
||||
else:
|
||||
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
|
||||
for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]):
|
||||
for i in range(weight.shape[1]):
|
||||
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1]
|
||||
|
||||
return [hidden_states1, hidden_states2, pool2]
|
||||
|
||||
|
||||
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
||||
|
||||
@@ -74,6 +74,7 @@ import imagesize
|
||||
import cv2
|
||||
import safetensors.torch
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
@@ -3581,7 +3582,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
# available backends:
|
||||
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
|
||||
# https://pytorch.org/docs/stable/torch.compiler.html
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt", "tensort", "ipex", "tvm"],
|
||||
choices=[
|
||||
"eager",
|
||||
"aot_eager",
|
||||
"inductor",
|
||||
"aot_ts_nvfuser",
|
||||
"nvprims_nvfuser",
|
||||
"cudagraphs",
|
||||
"ofi",
|
||||
"fx2trt",
|
||||
"onnxrt",
|
||||
"tensort",
|
||||
"ipex",
|
||||
"tvm",
|
||||
],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
@@ -5850,8 +5864,8 @@ def sample_images_common(
|
||||
pipe_class,
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
epoch: int,
|
||||
steps: int,
|
||||
device,
|
||||
vae,
|
||||
tokenizer,
|
||||
@@ -5910,11 +5924,7 @@ def sample_images_common(
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
# schedulers: dict = {} cannot find where this is used
|
||||
default_scheduler = get_my_scheduler(
|
||||
sample_sampler=args.sample_sampler,
|
||||
v_parameterization=args.v_parameterization,
|
||||
)
|
||||
default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization)
|
||||
|
||||
pipeline = pipe_class(
|
||||
text_encoder=text_encoder,
|
||||
@@ -5975,21 +5985,18 @@ def sample_images_common(
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
|
||||
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
|
||||
# with torch.cuda.device(torch.cuda.current_device()):
|
||||
# torch.cuda.empty_cache()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
if torch.cuda.is_available() and cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
pipeline,
|
||||
pipeline: Union[StableDiffusionLongPromptWeightingPipeline, SdxlStableDiffusionLongPromptWeightingPipeline],
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
|
||||
@@ -83,6 +83,7 @@ def train(args):
|
||||
|
||||
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||
tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images
|
||||
|
||||
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||
@@ -436,19 +437,19 @@ def train(args):
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# # For --sample_at_first
|
||||
# sdxl_train_util.sample_images(
|
||||
# accelerator,
|
||||
# args,
|
||||
# 0,
|
||||
# global_step,
|
||||
# accelerator.device,
|
||||
# vae,
|
||||
# [tokenizer1, tokenizer2],
|
||||
# [text_encoder1, text_encoder2],
|
||||
# unet,
|
||||
# controlnet=control_net,
|
||||
# )
|
||||
# For --sample_at_first
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
0,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
|
||||
unet,
|
||||
controlnet=control_net,
|
||||
)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
@@ -484,7 +485,7 @@ def train(args):
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
|
||||
tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2]
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
|
||||
@@ -558,18 +559,18 @@ def train(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
# sdxl_train_util.sample_images(
|
||||
# accelerator,
|
||||
# args,
|
||||
# None,
|
||||
# global_step,
|
||||
# accelerator.device,
|
||||
# vae,
|
||||
# [tokenizer1, tokenizer2],
|
||||
# [text_encoder1, text_encoder2],
|
||||
# unet,
|
||||
# controlnet=control_net,
|
||||
# )
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
|
||||
unet,
|
||||
controlnet=control_net,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
@@ -628,7 +629,7 @@ def train(args):
|
||||
accelerator.device,
|
||||
vae,
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
|
||||
unet,
|
||||
controlnet=control_net,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user