fix train controlnet

This commit is contained in:
青龍聖者@bdsqlsz
2024-04-20 21:26:09 +08:00
parent 71e2c91330
commit 4477116a64
3 changed files with 9 additions and 4 deletions

View File

@@ -1982,8 +1982,8 @@ class ControlNetDataset(BaseDataset):
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8):
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor)
def __len__(self):
return self.dreambooth_dataset_delegate.__len__()

View File

@@ -17,6 +17,7 @@ easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.20.1
omegaconf==2.3.0
# for Image utils
imagesize==1.4.1
# for BLIP captioning

View File

@@ -5,7 +5,7 @@ import os
import random
import time
from multiprocessing import Value
from types import SimpleNamespace
from omegaconf import OmegaConf
import toml
from tqdm import tqdm
@@ -148,8 +148,10 @@ def train(args):
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": [5, 10, 20, 20],
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 4,
@@ -179,8 +181,10 @@ def train(args):
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": 8,
"out_channels": 4,
"sample_size": 64,
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
@@ -193,7 +197,7 @@ def train(args):
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
unet.config = SimpleNamespace(**unet.config)
unet.config = OmegaConf.create(unet.config)
controlnet = ControlNetModel.from_unet(unet)