From 97611e89cab7137b994defd65a1e2ba29c0567e5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 2 Jul 2023 16:49:11 +0900 Subject: [PATCH] remove debug code --- library/sdxl_lpw_stable_diffusion.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 786b8119..e806bc61 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -711,19 +711,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: # self.vae.set_use_memory_efficient_attention_xformers(False) # image = self.vae.decode(latents.to("cpu")).sample - print("default dtype:", torch.get_default_dtype()) - assert latents.dtype == torch.float32 - assert self.vae.post_quant_conv.weight.dtype == torch.float32 - print("device:", latents.device, "latents dtype:", latents.dtype, "weight dtype:", self.vae.post_quant_conv.weight.dtype) - w = torch.randn_like(self.vae.post_quant_conv.weight, dtype=torch.float32, device=latents.device) - x = torch.randn_like(latents, dtype=torch.float32, device=latents.device) - x = torch.nn.functional.conv2d(x, w) - print("result dtype:", x.dtype) # float16 !! - w = torch.randn_like(self.vae.post_quant_conv.weight, dtype=torch.float32, device="cpu") - x = torch.randn_like(latents, dtype=torch.float32, device="cpu") - x = torch.nn.functional.conv2d(x, w) - print("result dtype:", x.dtype) # float32 - image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16