This commit is contained in:
gesen2egee
2024-03-10 04:37:16 +08:00
committed by rockerBOO
parent 569ca72fc4
commit 8743532963
3 changed files with 102 additions and 52 deletions

View File

@@ -122,6 +122,20 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
def split_train_val(paths, is_train, validation_split, validation_seed):
if validation_seed is not None:
print(f"Using validation seed: {validation_seed}")
prevstate = random.getstate()
random.seed(validation_seed)
random.shuffle(paths)
random.setstate(prevstate)
else:
random.shuffle(paths)
if is_train:
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
else:
return paths[len(paths) - round(len(paths) * validation_split):]
def split_train_val(paths, is_train, validation_split, validation_seed):
if validation_seed is not None:
@@ -1352,7 +1366,6 @@ class DreamBoothDataset(BaseDataset):
self.is_train = is_train
self.validation_split = validation_split
self.validation_seed = validation_seed
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
@@ -1405,10 +1418,9 @@ class DreamBoothDataset(BaseDataset):
return [], []
img_paths = glob_images(subset.image_dir, "*")
if self.validation_split > 0.0:
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []