mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix to work with fp16, crash with some reso
This commit is contained in:
@@ -73,6 +73,12 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
|
||||
_self.norm1.to(cpu_device)
|
||||
_self.norm2.to(cpu_device)
|
||||
|
||||
# GroupNormがCPUでfp16で動かない対策
|
||||
org_dtype = input_tensor.dtype
|
||||
if org_dtype == torch.float16:
|
||||
_self.norm1.to(torch.float32)
|
||||
_self.norm2.to(torch.float32)
|
||||
|
||||
# すべてのテンソルをCPUに移動する
|
||||
input_tensor = input_tensor.to(cpu_device)
|
||||
hidden_states = input_tensor
|
||||
@@ -93,7 +99,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
|
||||
# return num_div, x
|
||||
|
||||
# normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
hidden_states = _self.norm1(hidden_states) # run on cpu
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
|
||||
sliced = slice_h(hidden_states, num_slices)
|
||||
del hidden_states
|
||||
@@ -113,7 +123,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
|
||||
hidden_states = cat_h(sliced)
|
||||
del sliced
|
||||
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
hidden_states = _self.norm2(hidden_states) # run on cpu
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
|
||||
sliced = slice_h(hidden_states, num_slices)
|
||||
del hidden_states
|
||||
@@ -455,7 +469,7 @@ class SlicingDecoder(nn.Module):
|
||||
|
||||
sliced = slice_h(sample, self.num_slices)
|
||||
del sample
|
||||
for i in range(self.num_slices):
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
@@ -481,7 +495,7 @@ class SlicingDecoder(nn.Module):
|
||||
sliced = slice_h(hidden_states, num_slices)
|
||||
del hidden_states
|
||||
|
||||
for i in range(num_slices):
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user