mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix dtype for vae
This commit is contained in:
@@ -129,11 +129,11 @@ if __name__ == "__main__":
|
|||||||
unet.to(DEVICE, dtype=DTYPE)
|
unet.to(DEVICE, dtype=DTYPE)
|
||||||
unet.eval()
|
unet.eval()
|
||||||
|
|
||||||
|
vae_dtype = DTYPE
|
||||||
if DTYPE == torch.float16:
|
if DTYPE == torch.float16:
|
||||||
print("use float32 for vae")
|
print("use float32 for vae")
|
||||||
vae.to(DEVICE, torch.float32) # avoid black image, same as no-half-vae
|
vae_dtype = torch.float32
|
||||||
else:
|
vae.to(DEVICE, dtype=vae_dtype)
|
||||||
vae.to(DEVICE, DTYPE)
|
|
||||||
vae.eval()
|
vae.eval()
|
||||||
|
|
||||||
text_model1.to(DEVICE, dtype=DTYPE)
|
text_model1.to(DEVICE, dtype=DTYPE)
|
||||||
@@ -278,7 +278,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# latents = 1 / 0.18215 * latents
|
# latents = 1 / 0.18215 * latents
|
||||||
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
||||||
latents = latents.to(torch.float32)
|
latents = latents.to(vae_dtype)
|
||||||
image = vae.decode(latents).sample
|
image = vae.decode(latents).sample
|
||||||
image = (image / 2 + 0.5).clamp(0, 1)
|
image = (image / 2 + 0.5).clamp(0, 1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user