Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
fix to work with fp16, crash with some reso
@@ -73,6 +73,12 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
     _self.norm1.to(cpu_device)
     _self.norm2.to(cpu_device)
 
+    # workaround: GroupNorm does not run with fp16 on the CPU
+    org_dtype = input_tensor.dtype
+    if org_dtype == torch.float16:
+        _self.norm1.to(torch.float32)
+        _self.norm2.to(torch.float32)
+
     # move all tensors to the CPU
     input_tensor = input_tensor.to(cpu_device)
     hidden_states = input_tensor
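The hunk above only switches the affected GroupNorm modules to float32 when the incoming activations are fp16; the normalization itself still runs on the CPU. A minimal sketch of why this is needed (an assumption on my part, not stated in the commit: GroupNorm on the CPU typically rejects float16 inputs in PyTorch):

import torch
import torch.nn as nn

norm = nn.GroupNorm(num_groups=32, num_channels=64)
x = torch.randn(1, 64, 8, 8, dtype=torch.float16)  # fp16 activations on the CPU

# Running the norm directly in fp16 on the CPU usually fails with a
# "not implemented for 'Half'"-style error, so the commit keeps the norm
# weights in float32 and casts the activations around the call instead:
norm.to(torch.float32)
y = norm(x.to(torch.float32)).to(x.dtype)
print(y.dtype)  # torch.float16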
@@ -93,7 +99,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
     # return num_div, x
 
     # splitting the norm changes the result, so only this part is not sliced; computing it on the GPU would run out of VRAM, so it is done on the CPU (fortunately the CPU is not that slow here)
+    if org_dtype == torch.float16:
+        hidden_states = hidden_states.to(torch.float32)
     hidden_states = _self.norm1(hidden_states)  # run on cpu
+    if org_dtype == torch.float16:
+        hidden_states = hidden_states.to(torch.float16)
 
     sliced = slice_h(hidden_states, num_slices)
     del hidden_states
@@ -113,7 +123,11 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
     hidden_states = cat_h(sliced)
     del sliced
 
+    if org_dtype == torch.float16:
+        hidden_states = hidden_states.to(torch.float32)
     hidden_states = _self.norm2(hidden_states)  # run on cpu
+    if org_dtype == torch.float16:
+        hidden_states = hidden_states.to(torch.float16)
 
     sliced = slice_h(hidden_states, num_slices)
     del hidden_states
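The two hunks above apply the same pattern around _self.norm1 and _self.norm2: cast the fp16 activations up to float32, run the norm on the CPU, then cast back. A hedged sketch of that pattern factored into a standalone helper (the helper name is mine, not part of the commit):

import torch

def norm_on_cpu(norm, hidden_states):
    # The norm runs on the CPU to save VRAM; fp16 GroupNorm is not supported there,
    # so temporarily compute in float32 and restore the original dtype afterwards.
    org_dtype = hidden_states.dtype
    if org_dtype == torch.float16:
        hidden_states = hidden_states.to(torch.float32)
    hidden_states = norm(hidden_states)  # run on cpu
    if org_dtype == torch.float16:
        hidden_states = hidden_states.to(org_dtype)
    return hidden_states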
@@ -455,7 +469,7 @@ class SlicingDecoder(nn.Module):
 
         sliced = slice_h(sample, self.num_slices)
         del sample
-        for i in range(self.num_slices):
+        for i in range(len(sliced)):
             x = sliced[i]
             sliced[i] = None
 
@@ -481,7 +495,7 @@ class SlicingDecoder(nn.Module):
         sliced = slice_h(hidden_states, num_slices)
         del hidden_states
 
-        for i in range(num_slices):
+        for i in range(len(sliced)):
             x = sliced[i]
             sliced[i] = None
 
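Both loop fixes above iterate over len(sliced) instead of the requested slice count, which matches the "crash with some reso" part of the commit message: for some resolutions the height apparently cannot be split into num_slices pieces, so indexing up to num_slices overruns the list. A hedged illustration, assuming slice_h behaves like torch.chunk along the height axis (the actual slice_h implementation is not shown in this diff):

import torch

sample = torch.randn(1, 4, 3, 64)             # only 3 rows of height
sliced = list(torch.chunk(sample, 8, dim=2))  # ask for 8 slices, get just 3
print(len(sliced))                            # 3

# range(8) would raise IndexError at sliced[3]; range(len(sliced)) only
# visits the slices that actually exist, mirroring the fix in both hunks.
for i in range(len(sliced)):
    x = sliced[i]
    sliced[i] = None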