mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
val
This commit is contained in:
@@ -134,6 +134,20 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
|
||||
def split_train_val(paths, is_train, validation_split, validation_seed):
|
||||
if validation_seed is not None:
|
||||
print(f"Using validation seed: {validation_seed}")
|
||||
prevstate = random.getstate()
|
||||
random.seed(validation_seed)
|
||||
random.shuffle(paths)
|
||||
random.setstate(prevstate)
|
||||
else:
|
||||
random.shuffle(paths)
|
||||
|
||||
if is_train:
|
||||
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
|
||||
else:
|
||||
return paths[len(paths) - round(len(paths) * validation_split):]
|
||||
|
||||
class ImageInfo:
|
||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||
@@ -1360,6 +1374,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[DreamBoothSubset],
|
||||
is_train: bool,
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
@@ -1371,12 +1386,17 @@ class DreamBoothDataset(BaseDataset):
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
prior_loss_weight: float,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
self.is_train = is_train
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.prior_loss_weight = prior_loss_weight
|
||||
@@ -1429,6 +1449,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
return [], []
|
||||
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
if self.validation_split > 0.0:
|
||||
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
|
||||
Reference in New Issue
Block a user