fix conditioning

This commit is contained in:
Kohya S
2023-07-09 19:00:38 +09:00
parent a380502c01
commit 77ec70d145

View File

@@ -285,7 +285,7 @@ def get_unweighted_text_embeddings(
def get_weighted_text_embeddings( def get_weighted_text_embeddings(
pipe: StableDiffusionPipeline, pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None, uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 3, max_embeddings_multiples: Optional[int] = 3,
@@ -657,11 +657,9 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
uncond_pool = uncond_pool.repeat(1, num_images_per_prompt) uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1) uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) return text_embeddings, text_pool, uncond_embeddings, uncond_pool
if text_pool is not None:
text_pool = torch.cat([uncond_pool, text_pool])
return text_embeddings, text_pool return text_embeddings, text_pool, None, None
def check_inputs(self, prompt, height, width, strength, callback_steps): def check_inputs(self, prompt, height, width, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list): if not isinstance(prompt, str) and not isinstance(prompt, list):
@@ -671,7 +669,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
print(height, width)
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or ( if (callback_steps is None) or (
@@ -901,12 +898,14 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
# 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す # 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す
# To simplify the implementation, switch the tokenzer/text encoder and call it twice # To simplify the implementation, switch the tokenzer/text encoder and call it twice
text_embeddings_list = [] text_embeddings_list = []
text_pools = [] text_pool = None
uncond_embeddings_list = []
uncond_pool = None
for i in range(len(self.tokenizers)): for i in range(len(self.tokenizers)):
self.tokenizer = self.tokenizers[i] self.tokenizer = self.tokenizers[i]
self.text_encoder = self.text_encoders[i] self.text_encoder = self.text_encoders[i]
text_embeddings, text_pool = self._encode_prompt( text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
prompt, prompt,
device, device,
num_images_per_prompt, num_images_per_prompt,
@@ -916,7 +915,12 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
is_sdxl_text_encoder2=i == 1, is_sdxl_text_encoder2=i == 1,
) )
text_embeddings_list.append(text_embeddings) text_embeddings_list.append(text_embeddings)
text_pools.append(text_pool) uncond_embeddings_list.append(uncond_embeddings)
if tp1 is not None:
text_pool = tp1
if up1 is not None:
uncond_pool = up1
dtype = text_embeddings_list[0].dtype dtype = text_embeddings_list[0].dtype
@@ -965,11 +969,19 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
crop_size = torch.zeros_like(orig_size) crop_size = torch.zeros_like(orig_size)
target_size = orig_size target_size = orig_size
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype) embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
if do_classifier_free_guidance:
embs = torch.cat([embs] * 2)
vector_embedding = torch.cat([text_pools[1], embs], dim=1).to(dtype) # make conditionings
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype) if do_classifier_free_guidance:
text_embeddings = torch.cat(text_embeddings_list, dim=2)
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
text_embedding = torch.cat([text_embeddings, uncond_embeddings]).to(dtype)
cond_vector = torch.cat([text_pool, embs], dim=1)
uncond_vector = torch.cat([uncond_pool, embs], dim=1)
vector_embedding = torch.cat([cond_vector, uncond_vector]).to(dtype)
else:
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
# 8. Denoising loop # 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):