mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
remove unnecessary code
This commit is contained in:
@@ -734,24 +734,6 @@ class Transformer2DModel(nn.Module):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def forward_xxx(self, hidden_states, encoder_hidden_states=None, timestep=None):
    """Run the transformer body, optionally under gradient checkpointing.

    When the module is in training mode and ``self.gradient_checkpointing``
    is enabled, ``self.forward_body`` is executed through
    ``torch.utils.checkpoint.checkpoint`` so activations are recomputed on
    the backward pass instead of stored; otherwise the body is called
    directly.  Returns whatever ``self.forward_body`` returns.
    """
    # Fast path: no checkpointing requested (or not training) — call through.
    if not (self.training and self.gradient_checkpointing):
        return self.forward_body(hidden_states, encoder_hidden_states, timestep)

    # torch.utils.checkpoint expects a plain callable taking positional
    # tensors; wrap the bound method so its inputs are forwarded unchanged.
    def make_checkpoint_fn(fn):
        def checkpoint_fn(*args):
            return fn(*args)

        return checkpoint_fn

    return torch.utils.checkpoint.checkpoint(
        make_checkpoint_fn(self.forward_body),
        hidden_states,
        encoder_hidden_states,
        timestep,
    )
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample2D(nn.Module):
|
class Upsample2D(nn.Module):
|
||||||
def __init__(self, channels, out_channels):
|
def __init__(self, channels, out_channels):
|
||||||
|
|||||||
Reference in New Issue
Block a user