mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
remove debug code
This commit is contained in:
@@ -711,19 +711,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
|||||||
# self.vae.set_use_memory_efficient_attention_xformers(False)
|
# self.vae.set_use_memory_efficient_attention_xformers(False)
|
||||||
# image = self.vae.decode(latents.to("cpu")).sample
|
# image = self.vae.decode(latents.to("cpu")).sample
|
||||||
|
|
||||||
print("default dtype:", torch.get_default_dtype())
|
|
||||||
assert latents.dtype == torch.float32
|
|
||||||
assert self.vae.post_quant_conv.weight.dtype == torch.float32
|
|
||||||
print("device:", latents.device, "latents dtype:", latents.dtype, "weight dtype:", self.vae.post_quant_conv.weight.dtype)
|
|
||||||
w = torch.randn_like(self.vae.post_quant_conv.weight, dtype=torch.float32, device=latents.device)
|
|
||||||
x = torch.randn_like(latents, dtype=torch.float32, device=latents.device)
|
|
||||||
x = torch.nn.functional.conv2d(x, w)
|
|
||||||
print("result dtype:", x.dtype) # float16 !!
|
|
||||||
w = torch.randn_like(self.vae.post_quant_conv.weight, dtype=torch.float32, device="cpu")
|
|
||||||
x = torch.randn_like(latents, dtype=torch.float32, device="cpu")
|
|
||||||
x = torch.nn.functional.conv2d(x, w)
|
|
||||||
print("result dtype:", x.dtype) # float32
|
|
||||||
|
|
||||||
image = self.vae.decode(latents).sample
|
image = self.vae.decode(latents).sample
|
||||||
image = (image / 2 + 0.5).clamp(0, 1)
|
image = (image / 2 + 0.5).clamp(0, 1)
|
||||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||||
|
|||||||
Reference in New Issue
Block a user