mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add lora controlnet train/gen temporarily
This commit is contained in:
@@ -39,6 +39,7 @@ CONTEXT_DIM: int = 2048
|
||||
MODEL_CHANNELS: int = 320
|
||||
TIME_EMBED_DIM = 320 * 4
|
||||
|
||||
USE_REENTRANT = True
|
||||
|
||||
# region memory effcient attention
|
||||
|
||||
@@ -322,7 +323,7 @@ class ResnetBlock2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb)
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT)
|
||||
else:
|
||||
x = self.forward_body(x, emb)
|
||||
|
||||
@@ -356,7 +357,9 @@ class Downsample2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT
|
||||
)
|
||||
else:
|
||||
hidden_states = self.forward_body(hidden_states)
|
||||
|
||||
@@ -641,7 +644,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
output = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states, context, timestep)
|
||||
output = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT
|
||||
)
|
||||
else:
|
||||
output = self.forward_body(hidden_states, context, timestep)
|
||||
|
||||
@@ -782,7 +787,9 @@ class Upsample2D(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states, output_size)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT
|
||||
)
|
||||
else:
|
||||
hidden_states = self.forward_body(hidden_states, output_size)
|
||||
|
||||
|
||||
@@ -1743,6 +1743,9 @@ class ControlNetDataset(BaseDataset):
|
||||
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||
|
||||
def __len__(self):
|
||||
return self.dreambooth_dataset_delegate.__len__()
|
||||
|
||||
@@ -1775,9 +1778,14 @@ class ControlNetDataset(BaseDataset):
|
||||
h, w = target_size_hw
|
||||
cond_img = cond_img[ct : ct + h, cl : cl + w]
|
||||
else:
|
||||
assert (
|
||||
cond_img.shape[0] == self.height and cond_img.shape[1] == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
# assert (
|
||||
# cond_img.shape[0] == self.height and cond_img.shape[1] == self.width
|
||||
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
# resize to target
|
||||
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
|
||||
cond_img = cv2.resize(
|
||||
cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
|
||||
)
|
||||
|
||||
if flipped:
|
||||
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||
|
||||
Reference in New Issue
Block a user