mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
val
This commit is contained in:
@@ -81,23 +81,24 @@ class ControlNetSubsetParams(BaseSubsetParams):
|
||||
|
||||
@dataclass
|
||||
class BaseDatasetParams:
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
|
||||
max_token_length: int = None
|
||||
resolution: Optional[Tuple[int, int]] = None
|
||||
debug_dataset: bool = False
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
|
||||
max_token_length: int = None
|
||||
resolution: Optional[Tuple[int, int]] = None
|
||||
network_multiplier: float = 1.0
|
||||
debug_dataset: bool = False
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
batch_size: int = 1
|
||||
enable_bucket: bool = False
|
||||
min_bucket_reso: int = 256
|
||||
max_bucket_reso: int = 1024
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
prior_loss_weight: float = 1.0
|
||||
|
||||
batch_size: int = 1
|
||||
enable_bucket: bool = False
|
||||
min_bucket_reso: int = 256
|
||||
max_bucket_reso: int = 1024
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
prior_loss_weight: float = 1.0
|
||||
|
||||
@dataclass
|
||||
class FineTuningDatasetParams(BaseDatasetParams):
|
||||
batch_size: int = 1
|
||||
@@ -203,8 +204,9 @@ class ConfigSanitizer:
|
||||
"max_bucket_reso": int,
|
||||
"min_bucket_reso": int,
|
||||
"validation_seed": int,
|
||||
"validation_split": float,
|
||||
"validation_split": float,
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
}
|
||||
|
||||
# options handled by argparse but not handled by user config
|
||||
|
||||
@@ -122,6 +122,20 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
|
||||
def split_train_val(paths, is_train, validation_split, validation_seed):
|
||||
if validation_seed is not None:
|
||||
print(f"Using validation seed: {validation_seed}")
|
||||
prevstate = random.getstate()
|
||||
random.seed(validation_seed)
|
||||
random.shuffle(paths)
|
||||
random.setstate(prevstate)
|
||||
else:
|
||||
random.shuffle(paths)
|
||||
|
||||
if is_train:
|
||||
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
|
||||
else:
|
||||
return paths[len(paths) - round(len(paths) * validation_split):]
|
||||
|
||||
def split_train_val(paths, is_train, validation_split, validation_seed):
|
||||
if validation_seed is not None:
|
||||
@@ -1352,7 +1366,6 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.is_train = is_train
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.prior_loss_weight = prior_loss_weight
|
||||
@@ -1405,10 +1418,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
return [], []
|
||||
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
|
||||
if self.validation_split > 0.0:
|
||||
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
|
||||
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
|
||||
Reference in New Issue
Block a user