mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix conditioning
This commit is contained in:
@@ -285,7 +285,7 @@ def get_unweighted_text_embeddings(
|
||||
|
||||
|
||||
def get_weighted_text_embeddings(
|
||||
pipe: StableDiffusionPipeline,
|
||||
pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline,
|
||||
prompt: Union[str, List[str]],
|
||||
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
||||
max_embeddings_multiples: Optional[int] = 3,
|
||||
@@ -657,11 +657,9 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
|
||||
uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
|
||||
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
if text_pool is not None:
|
||||
text_pool = torch.cat([uncond_pool, text_pool])
|
||||
return text_embeddings, text_pool, uncond_embeddings, uncond_pool
|
||||
|
||||
return text_embeddings, text_pool
|
||||
return text_embeddings, text_pool, None, None
|
||||
|
||||
def check_inputs(self, prompt, height, width, strength, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
@@ -671,7 +669,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
print(height, width)
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
@@ -901,12 +898,14 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
# 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す
|
||||
# To simplify the implementation, switch the tokenzer/text encoder and call it twice
|
||||
text_embeddings_list = []
|
||||
text_pools = []
|
||||
text_pool = None
|
||||
uncond_embeddings_list = []
|
||||
uncond_pool = None
|
||||
for i in range(len(self.tokenizers)):
|
||||
self.tokenizer = self.tokenizers[i]
|
||||
self.text_encoder = self.text_encoders[i]
|
||||
|
||||
text_embeddings, text_pool = self._encode_prompt(
|
||||
text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
@@ -916,7 +915,12 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
is_sdxl_text_encoder2=i == 1,
|
||||
)
|
||||
text_embeddings_list.append(text_embeddings)
|
||||
text_pools.append(text_pool)
|
||||
uncond_embeddings_list.append(uncond_embeddings)
|
||||
|
||||
if tp1 is not None:
|
||||
text_pool = tp1
|
||||
if up1 is not None:
|
||||
uncond_pool = up1
|
||||
|
||||
dtype = text_embeddings_list[0].dtype
|
||||
|
||||
@@ -965,11 +969,19 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
crop_size = torch.zeros_like(orig_size)
|
||||
target_size = orig_size
|
||||
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
|
||||
if do_classifier_free_guidance:
|
||||
embs = torch.cat([embs] * 2)
|
||||
|
||||
vector_embedding = torch.cat([text_pools[1], embs], dim=1).to(dtype)
|
||||
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
|
||||
# make conditionings
|
||||
if do_classifier_free_guidance:
|
||||
text_embeddings = torch.cat(text_embeddings_list, dim=2)
|
||||
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
|
||||
text_embedding = torch.cat([text_embeddings, uncond_embeddings]).to(dtype)
|
||||
|
||||
cond_vector = torch.cat([text_pool, embs], dim=1)
|
||||
uncond_vector = torch.cat([uncond_pool, embs], dim=1)
|
||||
vector_embedding = torch.cat([cond_vector, uncond_vector]).to(dtype)
|
||||
else:
|
||||
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
|
||||
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
|
||||
|
||||
# 8. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
|
||||
Reference in New Issue
Block a user