remove debug code

This commit is contained in:
Kohya S
2023-07-02 16:49:11 +09:00
parent 64cf922841
commit 97611e89ca

View File

@@ -711,19 +711,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
# self.vae.set_use_memory_efficient_attention_xformers(False)
# image = self.vae.decode(latents.to("cpu")).sample
print("default dtype:", torch.get_default_dtype())
assert latents.dtype == torch.float32
assert self.vae.post_quant_conv.weight.dtype == torch.float32
print("device:", latents.device, "latents dtype:", latents.dtype, "weight dtype:", self.vae.post_quant_conv.weight.dtype)
w = torch.randn_like(self.vae.post_quant_conv.weight, dtype=torch.float32, device=latents.device)
x = torch.randn_like(latents, dtype=torch.float32, device=latents.device)
x = torch.nn.functional.conv2d(x, w)
print("result dtype:", x.dtype) # float16 !!
w = torch.randn_like(self.vae.post_quant_conv.weight, dtype=torch.float32, device="cpu")
x = torch.randn_like(latents, dtype=torch.float32, device="cpu")
x = torch.nn.functional.conv2d(x, w)
print("result dtype:", x.dtype) # float32
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16