fix to work with fp16, crash with some reso

This commit is contained in:
Kohya S
2023-05-12 21:44:07 +09:00
parent 47b6101465
commit 41dd835a89

View File

@@ -73,6 +73,12 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
_self.norm1.to(cpu_device) _self.norm1.to(cpu_device)
_self.norm2.to(cpu_device) _self.norm2.to(cpu_device)
# GroupNormがCPUでfp16で動かない対策
org_dtype = input_tensor.dtype
if org_dtype == torch.float16:
_self.norm1.to(torch.float32)
_self.norm2.to(torch.float32)
# すべてのテンソルをCPUに移動する # すべてのテンソルをCPUに移動する
input_tensor = input_tensor.to(cpu_device) input_tensor = input_tensor.to(cpu_device)
hidden_states = input_tensor hidden_states = input_tensor
@@ -93,7 +99,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
# return num_div, x # return num_div, x
# normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない # normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float32)
hidden_states = _self.norm1(hidden_states) # run on cpu hidden_states = _self.norm1(hidden_states) # run on cpu
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float16)
sliced = slice_h(hidden_states, num_slices) sliced = slice_h(hidden_states, num_slices)
del hidden_states del hidden_states
@@ -113,7 +123,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
hidden_states = cat_h(sliced) hidden_states = cat_h(sliced)
del sliced del sliced
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float32)
hidden_states = _self.norm2(hidden_states) # run on cpu hidden_states = _self.norm2(hidden_states) # run on cpu
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float16)
sliced = slice_h(hidden_states, num_slices) sliced = slice_h(hidden_states, num_slices)
del hidden_states del hidden_states
@@ -455,7 +469,7 @@ class SlicingDecoder(nn.Module):
sliced = slice_h(sample, self.num_slices) sliced = slice_h(sample, self.num_slices)
del sample del sample
for i in range(self.num_slices): for i in range(len(sliced)):
x = sliced[i] x = sliced[i]
sliced[i] = None sliced[i] = None
@@ -481,7 +495,7 @@ class SlicingDecoder(nn.Module):
sliced = slice_h(hidden_states, num_slices) sliced = slice_h(hidden_states, num_slices)
del hidden_states del hidden_states
for i in range(num_slices): for i in range(len(sliced)):
x = sliced[i] x = sliced[i]
sliced[i] = None sliced[i] = None