Add Validation loss for LoRA training

This commit is contained in:
Hina Chen
2024-12-27 16:47:59 +08:00
parent e89653975d
commit 05bb9183fa
3 changed files with 257 additions and 6 deletions

View File

@@ -145,6 +145,17 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
def split_train_val(paths: List[str], validation_split: float, validation_seed: int) -> List[str]:
if validation_seed is not None:
print(f"Using validation seed: {validation_seed}")
prevstate = random.getstate()
random.seed(validation_seed)
random.shuffle(paths)
random.setstate(prevstate)
else:
random.shuffle(paths)
return paths[len(paths) - round(len(paths) * validation_split):]
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -397,6 +408,8 @@ class BaseSubset:
token_warmup_min: int,
token_warmup_step: Union[float, int],
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -424,6 +437,9 @@ class BaseSubset:
self.img_count = 0
self.validation_seed = validation_seed
self.validation_split = validation_split
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -453,6 +469,8 @@ class DreamBoothSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -478,6 +496,8 @@ class DreamBoothSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)
self.is_reg = is_reg
@@ -518,6 +538,8 @@ class FineTuningSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -543,6 +565,8 @@ class FineTuningSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)
self.metadata_file = metadata_file
@@ -579,6 +603,8 @@ class ControlNetSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -604,6 +630,8 @@ class ControlNetSubset(BaseSubset):
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)
self.conditioning_data_dir = conditioning_data_dir
@@ -1799,6 +1827,9 @@ class DreamBoothDataset(BaseDataset):
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset: bool,
is_train: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -1808,6 +1839,9 @@ class DreamBoothDataset(BaseDataset):
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
self.latents_cache = None
self.is_train = is_train
self.validation_seed = validation_seed
self.validation_split = validation_split
self.enable_bucket = enable_bucket
if self.enable_bucket:
@@ -1992,6 +2026,9 @@ class DreamBoothDataset(BaseDataset):
)
continue
if self.is_train == False:
img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed)
if subset.is_reg:
num_reg_images += subset.num_repeats * len(img_paths)
else:
@@ -2009,7 +2046,11 @@ class DreamBoothDataset(BaseDataset):
subset.img_count = len(img_paths)
self.subsets.append(subset)
logger.info(f"{num_train_images} train images with repeating.")
if self.is_train:
logger.info(f"{num_train_images} train images with repeating.")
else:
logger.info(f"{num_train_images} validation images with repeating.")
self.num_train_images = num_train_images
logger.info(f"{num_reg_images} reg images.")
@@ -2050,6 +2091,9 @@ class FineTuningDataset(BaseDataset):
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset: bool,
is_train: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -2276,6 +2320,9 @@ class ControlNetDataset(BaseDataset):
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset: float,
is_train: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -2324,6 +2371,9 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale,
1.0,
debug_dataset,
is_train,
validation_seed,
validation_split,
)
# config_util等から参照される値をいれておく若干微妙なのでなんとかしたい
@@ -4887,7 +4937,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "RAdamScheduleFree".lower():
optimizer_class = sf.RAdamScheduleFree
logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}")