mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix to work sample generation in fp8 ref #1057
This commit is contained in:
@@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
|||||||
if up1 is not None:
|
if up1 is not None:
|
||||||
uncond_pool = up1
|
uncond_pool = up1
|
||||||
|
|
||||||
dtype = self.unet.dtype
|
unet_dtype = self.unet.dtype
|
||||||
|
dtype = unet_dtype
|
||||||
|
if dtype.itemsize == 1: # fp8
|
||||||
|
dtype = torch.float16
|
||||||
|
self.unet.to(dtype)
|
||||||
|
|
||||||
# 4. Preprocess image and mask
|
# 4. Preprocess image and mask
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL.Image.Image):
|
||||||
@@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
|||||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
self.unet.to(unet_dtype)
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
def latents_to_image(self, latents):
|
def latents_to_image(self, latents):
|
||||||
|
|||||||
Reference in New Issue
Block a user