Merge pull request #1 from kohya-ss/main

Update to newest
This commit is contained in:
HkingAuditore
2023-05-15 11:22:02 +08:00
committed by GitHub

View File

@@ -73,6 +73,12 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
_self.norm1.to(cpu_device)
_self.norm2.to(cpu_device)
# GroupNormがCPUでfp16で動かない対策
org_dtype = input_tensor.dtype
if org_dtype == torch.float16:
_self.norm1.to(torch.float32)
_self.norm2.to(torch.float32)
# すべてのテンソルをCPUに移動する
input_tensor = input_tensor.to(cpu_device)
hidden_states = input_tensor
@@ -93,7 +99,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
# return num_div, x
# normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float32)
hidden_states = _self.norm1(hidden_states) # run on cpu
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float16)
sliced = slice_h(hidden_states, num_slices)
del hidden_states
@@ -113,7 +123,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
hidden_states = cat_h(sliced)
del sliced
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float32)
hidden_states = _self.norm2(hidden_states) # run on cpu
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float16)
sliced = slice_h(hidden_states, num_slices)
del hidden_states
@@ -455,7 +469,7 @@ class SlicingDecoder(nn.Module):
sliced = slice_h(sample, self.num_slices)
del sample
for i in range(self.num_slices):
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
@@ -481,7 +495,7 @@ class SlicingDecoder(nn.Module):
sliced = slice_h(hidden_states, num_slices)
del hidden_states
for i in range(num_slices):
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None