From 41dd835a893b998847a93ab1015d15690598ac40 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 12 May 2023 21:44:07 +0900 Subject: [PATCH] fix to work with fp16, crash with some reso --- library/slicing_vae.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 084bff68..490b5a75 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -73,6 +73,12 @@ def resblock_forward(_self, num_slices, input_tensor, temb): _self.norm1.to(cpu_device) _self.norm2.to(cpu_device) + # GroupNormがCPUでfp16で動かない対策 + org_dtype = input_tensor.dtype + if org_dtype == torch.float16: + _self.norm1.to(torch.float32) + _self.norm2.to(torch.float32) + # すべてのテンソルをCPUに移動する input_tensor = input_tensor.to(cpu_device) hidden_states = input_tensor @@ -93,7 +99,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb): # return num_div, x # normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない + if org_dtype == torch.float16: + hidden_states = hidden_states.to(torch.float32) hidden_states = _self.norm1(hidden_states) # run on cpu + if org_dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) sliced = slice_h(hidden_states, num_slices) del hidden_states @@ -113,7 +123,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb): hidden_states = cat_h(sliced) del sliced + if org_dtype == torch.float16: + hidden_states = hidden_states.to(torch.float32) hidden_states = _self.norm2(hidden_states) # run on cpu + if org_dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) sliced = slice_h(hidden_states, num_slices) del hidden_states @@ -455,7 +469,7 @@ class SlicingDecoder(nn.Module): sliced = slice_h(sample, self.num_slices) del sample - for i in range(self.num_slices): + for i in range(len(sliced)): x = sliced[i] sliced[i] = None @@ -481,7 +495,7 @@ class SlicingDecoder(nn.Module): sliced = slice_h(hidden_states, num_slices) del hidden_states - for i in range(num_slices): + for i in range(len(sliced)): x = sliced[i] sliced[i] = None