This commit is contained in:
gesen2egee
2024-03-10 04:37:16 +08:00
committed by rockerBOO
parent 569ca72fc4
commit 8743532963
3 changed files with 102 additions and 52 deletions

View File

@@ -81,23 +81,24 @@ class ControlNetSubsetParams(BaseSubsetParams):
# Dataset-level parameters shared by the dataset parameter classes that
# subclass this one (DreamBoothDatasetParams, FineTuningDatasetParams).
# NOTE(review): most fields below are listed twice. This looks like the removed
# and added sides of a diff hunk rendered without +/- markers, with
# `network_multiplier` present on one side only — confirm against the real
# file before relying on the field order shown here.
@dataclass
class BaseDatasetParams:
# A single CLIPTokenizer or a list of them; defaults to None.
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
# Annotated as int but defaults to None.
max_token_length: int = None
# Optional pair of ints; presumably (width, height) — TODO confirm.
resolution: Optional[Tuple[int, int]] = None
debug_dataset: bool = False
# Presumably the seed handed to split_train_val for a reproducible
# shuffle — verify against the dataset constructor.
validation_seed: Optional[int] = None
# Presumably the fraction of images held out for validation; 0.0 by default.
validation_split: float = 0.0
# Second listing of the same fields (see the NOTE above the class).
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
max_token_length: int = None
resolution: Optional[Tuple[int, int]] = None
# Only field that is listed once; defaults to 1.0.
network_multiplier: float = 1.0
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
# Dataset parameters for DreamBooth-style datasets; extends BaseDatasetParams
# with batch size, bucketing options and the prior loss weight.
# NOTE(review): every field is listed twice with identical type and default.
# This looks like the removed and added sides of a diff hunk rendered without
# +/- markers — confirm against the real file.
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
batch_size: int = 1
# Bucketing is off by default; the four settings below configure it.
enable_bucket: bool = False
# Presumably the smallest / largest bucket resolution in pixels — TODO confirm.
min_bucket_reso: int = 256
max_bucket_reso: int = 1024
# Presumably the step between bucket resolutions — TODO confirm.
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0
# Second listing of the same fields (see the NOTE above the class).
batch_size: int = 1
enable_bucket: bool = False
min_bucket_reso: int = 256
max_bucket_reso: int = 1024
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0
@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -203,8 +204,9 @@ class ConfigSanitizer:
"max_bucket_reso": int,
"min_bucket_reso": int,
"validation_seed": int,
"validation_split": float,
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
}
# options handled by argparse but not handled by user config

View File

@@ -122,6 +122,20 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
def split_train_val(paths, is_train, validation_split, validation_seed):
    """Return the training or the validation share of ``paths``.

    The paths are shuffled and then cut at a single index, so for a given
    shuffle the training and validation subsets are always complementary:
    no path is in both and none is dropped.

    Args:
        paths: Sequence of items (image paths) to split. It is not modified.
        is_train: When true return the training subset, otherwise the
            validation subset.
        validation_split: Fraction of the items assigned to validation.
            The resulting count is clamped to ``[0, len(paths)]``.
        validation_seed: Seed for a reproducible shuffle. When ``None`` the
            global ``random`` state is used instead.

    Returns:
        A new list holding the requested subset in shuffled order.
    """
    # Work on a copy: shuffling the caller's list in place would make a second
    # call on the same list (train, then validation) see a different order.
    shuffled = list(paths)

    if validation_seed is not None:
        print(f"Using validation seed: {validation_seed}")
        # A private generator yields the same permutation as seeding the
        # module-level one, without touching the global random state.
        random.Random(validation_seed).shuffle(shuffled)
    else:
        # NOTE: without a seed every call shuffles differently, so separate
        # train / validation calls are not guaranteed to be disjoint.
        random.shuffle(shuffled)

    total = len(shuffled)
    # Derive both subsets from one boundary. Computing the train size with
    # ceil() and the validation size with round() independently lets the two
    # slices overlap (e.g. 100 items at 0.7: 100 * (1 - 0.7) is
    # 30.000000000000004, whose ceil is 31, while validation starts at 30).
    val_count = min(total, max(0, round(total * validation_split)))
    split_at = total - val_count

    if is_train:
        return shuffled[:split_at]
    return shuffled[split_at:]
def split_train_val(paths, is_train, validation_split, validation_seed):
if validation_seed is not None:
@@ -1352,7 +1366,6 @@ class DreamBoothDataset(BaseDataset):
self.is_train = is_train
self.validation_split = validation_split
self.validation_seed = validation_seed
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
@@ -1405,10 +1418,9 @@ class DreamBoothDataset(BaseDataset):
return [], []
img_paths = glob_images(subset.image_dir, "*")
if self.validation_split > 0.0:
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []