This commit is contained in:
gesen2egee
2024-03-10 04:37:16 +08:00
parent 2d7389185c
commit b558a5b73d
3 changed files with 237 additions and 88 deletions

View File

@@ -134,6 +134,20 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
def split_train_val(paths, is_train, validation_split, validation_seed):
    """Return the training or the validation share of ``paths``.

    The paths are shuffled and then cut at a single boundary index; the
    training share is everything before the boundary and the validation
    share is everything from the boundary on. Because both shares are
    derived from the same index, two calls with the same inputs and the
    same ``validation_seed`` (one with ``is_train=True``, one with
    ``is_train=False``) are disjoint and together cover every path.

    Args:
        paths: Sequence of items to split. It is not modified.
        is_train: ``True`` to get the training share, ``False`` to get the
            validation share.
        validation_split: Fraction of the items assigned to validation.
            The resulting item count is clamped to ``[0, len(paths)]``.
        validation_seed: Seed for a reproducible shuffle, or ``None`` to
            shuffle with the current global ``random`` state.

    Returns:
        A new list holding the selected share, in shuffled order.

    Note:
        With ``validation_seed=None`` every call shuffles differently, so a
        training call and a validation call are NOT guaranteed to be
        disjoint. Pass a seed when both shares are needed.
    """
    # Shuffle a copy so the caller's list is not reordered as a side effect.
    shuffled = list(paths)

    if validation_seed is not None:
        print(f"Using validation seed: {validation_seed}")
        # Seed only for this shuffle and always restore the previous global
        # RNG state, so the rest of the program's random stream is unchanged.
        prevstate = random.getstate()
        random.seed(validation_seed)
        try:
            random.shuffle(shuffled)
        finally:
            random.setstate(prevstate)
    else:
        random.shuffle(shuffled)

    total = len(shuffled)
    # One boundary for both shares. Computing the train end with ceil() and
    # the validation start with round() independently could disagree (e.g.
    # 10 items at 0.15 gave train [0:9] and validation [8:]), leaking an
    # item into both sets.
    val_count = min(max(round(total * validation_split), 0), total)
    boundary = total - val_count

    if is_train:
        return shuffled[:boundary]
    return shuffled[boundary:]
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -1360,6 +1374,7 @@ class DreamBoothDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[DreamBoothSubset],
is_train: bool,
batch_size: int,
tokenizer,
max_token_length,
@@ -1371,12 +1386,17 @@ class DreamBoothDataset(BaseDataset):
bucket_reso_steps: int,
bucket_no_upscale: bool,
prior_loss_weight: float,
validation_split: float,
validation_seed: Optional[int],
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
self.is_train = is_train
self.validation_split = validation_split
self.validation_seed = validation_seed
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
@@ -1429,6 +1449,8 @@ class DreamBoothDataset(BaseDataset):
return [], []
img_paths = glob_images(subset.image_dir, "*")
if self.validation_split > 0.0:
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
# Load a prompt for each image file and, if one exists, use it instead