fix NaN in sampling image

This commit is contained in:
Kohya S
2023-07-11 23:18:35 +09:00
parent 2e67d74df4
commit 814996b14f
2 changed files with 118 additions and 118 deletions

View File

@@ -922,7 +922,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if up1 is not None: if up1 is not None:
uncond_pool = up1 uncond_pool = up1
dtype = text_embeddings_list[0].dtype dtype = self.unet.dtype
# 4. Preprocess image and mask # 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):

View File

@@ -3874,7 +3874,7 @@ def sample_images_common(
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
with torch.no_grad(): with torch.no_grad():
with accelerator.autocast(): # with accelerator.autocast():
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
if not accelerator.is_main_process: if not accelerator.is_main_process:
continue continue