add lora controlnet train/gen temporarily

This commit is contained in:
Kohya S
2023-08-17 10:08:02 +09:00
parent 983698dd1b
commit 3f7235c36f
6 changed files with 3582 additions and 83 deletions

View File

@@ -39,6 +39,7 @@ CONTEXT_DIM: int = 2048
MODEL_CHANNELS: int = 320
TIME_EMBED_DIM = 320 * 4
USE_REENTRANT = True
# region memory efficient attention
@@ -322,7 +323,7 @@ class ResnetBlock2D(nn.Module):
return custom_forward
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb)
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT)
else:
x = self.forward_body(x, emb)
@@ -356,7 +357,9 @@ class Downsample2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT
)
else:
hidden_states = self.forward_body(hidden_states)
@@ -641,7 +644,9 @@ class BasicTransformerBlock(nn.Module):
return custom_forward
output = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states, context, timestep)
output = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT
)
else:
output = self.forward_body(hidden_states, context, timestep)
@@ -782,7 +787,9 @@ class Upsample2D(nn.Module):
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states, output_size)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT
)
else:
hidden_states = self.forward_body(hidden_states, output_size)