mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
fix train controlnet
This commit is contained in:
@@ -1982,8 +1982,8 @@ class ControlNetDataset(BaseDataset):
|
||||
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8):
|
||||
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor)
|
||||
|
||||
def __len__(self):
|
||||
return self.dreambooth_dataset_delegate.__len__()
|
||||
|
||||
@@ -17,6 +17,7 @@ easygui==0.98.3
|
||||
toml==0.10.2
|
||||
voluptuous==0.13.1
|
||||
huggingface-hub==0.20.1
|
||||
omegaconf==2.3.0
|
||||
# for Image utils
|
||||
imagesize==1.4.1
|
||||
# for BLIP captioning
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import random
|
||||
import time
|
||||
from multiprocessing import Value
|
||||
from types import SimpleNamespace
|
||||
from omegaconf import OmegaConf
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -148,8 +148,10 @@ def train(args):
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": [5, 10, 20, 20],
|
||||
"num_class_embeds": None,
|
||||
"only_cross_attention": False,
|
||||
"out_channels": 4,
|
||||
@@ -179,8 +181,10 @@ def train(args):
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": 8,
|
||||
"out_channels": 4,
|
||||
"sample_size": 64,
|
||||
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
||||
@@ -193,7 +197,7 @@ def train(args):
|
||||
"resnet_time_scale_shift": "default",
|
||||
"projection_class_embeddings_input_dim": None,
|
||||
}
|
||||
unet.config = SimpleNamespace(**unet.config)
|
||||
unet.config = OmegaConf.create(unet.config)
|
||||
|
||||
controlnet = ControlNetModel.from_unet(unet)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user