Fix training, validation split, revert to using upstream implemenation

This commit is contained in:
rockerBOO
2025-01-03 15:20:25 -05:00
parent 6604b36044
commit 0522070d19
5 changed files with 152 additions and 160 deletions

View File

@@ -146,7 +146,15 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
def split_train_val(paths: List[str], is_train: bool, validation_split: float, validation_seed: int) -> List[str]:
def split_train_val(paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: int) -> List[str]:
"""
Split the dataset into train and validation
Shuffle the dataset based on the validation_seed or the current random seed.
For example if the split of 0.2 of 100 images.
[0:79] = 80 training images
[80:] = 20 validation images
"""
if validation_seed is not None:
print(f"Using validation seed: {validation_seed}")
prevstate = random.getstate()
@@ -156,9 +164,12 @@ def split_train_val(paths: List[str], is_train: bool, validation_split: float, v
else:
random.shuffle(paths)
if is_train:
# Split the dataset between training and validation
if is_training_dataset:
# Training dataset we split to the first part
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
else:
# Validation dataset we split to the second part
return paths[len(paths) - round(len(paths) * validation_split):]
@@ -1822,6 +1833,7 @@ class DreamBoothDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[DreamBoothSubset],
is_training_dataset: bool,
batch_size: int,
resolution,
network_multiplier: float,
@@ -1843,6 +1855,7 @@ class DreamBoothDataset(BaseDataset):
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
self.latents_cache = None
self.is_training_dataset = is_training_dataset
self.validation_seed = validation_seed
self.validation_split = validation_split
@@ -1952,6 +1965,9 @@ class DreamBoothDataset(BaseDataset):
size_set_count += 1
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
if self.validation_split > 0.0:
img_paths = split_train_val(img_paths, self.is_training_dataset, self.validation_split, self.validation_seed)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
if use_cached_info_for_subset:
@@ -2046,7 +2062,8 @@ class DreamBoothDataset(BaseDataset):
subset.img_count = len(img_paths)
self.subsets.append(subset)
logger.info(f"{num_train_images} train images with repeats.")
images_split_name = "train" if self.is_training_dataset else "validation"
logger.info(f"{num_train_images} {images_split_name} images with repeats.")
self.num_train_images = num_train_images
@@ -2411,8 +2428,12 @@ class ControlNetDataset(BaseDataset):
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
assert (
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
assert (
len(extra_imgs) == 0
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
self.conditioning_image_transforms = IMAGE_TRANSFORMS
@@ -4586,7 +4607,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = os.path.splitext(args.config_file)[0]
logger.info(args.config_file)
return args
@@ -5880,55 +5900,35 @@ def save_sd_model_on_train_end_common(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def get_random_timesteps(args, min_timestep: int, max_timestep: int, batch_size: int, device: torch.device) -> torch.IntTensor:
"""
Get a random timestep between the min and max timesteps
Can error (NotImplementedError) if the loss type is not supported
"""
# TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
timesteps = timesteps.repeat(batch_size).to(device)
elif args.loss_type == "l2":
timesteps = torch.randint(min_timestep, max_timestep, (batch_size,), device=device)
else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
return typing.cast(torch.IntTensor, timesteps)
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
return timesteps
def get_huber_c(args, noise_scheduler: DDPMScheduler, timesteps: torch.IntTensor) -> Optional[float]:
"""
Calculate the Huber convolution (huber_c) value
Huber loss is a loss function used in robust regression, that is less sensitive
to outliers in data than the squared error loss.
https://en.wikipedia.org/wiki/Huber_loss
"""
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.get('num_train_timesteps', 1000)
huber_c = math.exp(-alpha * timesteps.item())
elif args.huber_schedule == "snr":
if not hasattr(noise_scheduler, "alphas_cumprod"):
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
alphas_cumprod = noise_scheduler.alphas_cumprod.index_select(0, timesteps)
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = args.huber_c
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
elif args.loss_type == "l2":
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
return None
b_size = timesteps.shape[0]
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
result = torch.exp(-alpha * timesteps) * args.huber_scale
elif args.huber_schedule == "snr":
if not hasattr(noise_scheduler, "alphas_cumprod"):
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
result = result.to(timesteps.device)
elif args.huber_schedule == "constant":
result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
return huber_c
return result
def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor):
def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor:
"""
Apply noise modifications like noise offset and multires noise
"""
@@ -5964,27 +5964,44 @@ def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int,
max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep
# Sample a random timestep for each image
timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, device)
timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device)
return timesteps
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor, Optional[float]]:
"""
Unified noise, noisy_latents, timesteps and huber loss convolution calculations
"""
batch_size = latents.shape[0]
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
if args.noise_offset_random_strength:
noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
else:
noise_offset = args.noise_offset
noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
if args.multires_noise_iterations:
noise = custom_train_functions.pyramid_noise_like(
noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
)
# Sample a random timestep for each image
b_size = latents.shape[0]
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.get("num_train_timesteps", 1000) if args.max_timestep is None else args.max_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
# A random timestep for each image in the batch
timesteps = get_random_timesteps(args, min_timestep, max_timestep, batch_size, latents.device)
huber_c = get_huber_c(args, noise_scheduler, timesteps)
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
noise = make_noise(args, latents)
noisy_latents = get_noisy_latents(args, noise, noise_scheduler, latents, timesteps)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
if args.ip_noise_gamma_random_strength:
strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
else:
strength = args.ip_noise_gamma
noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
return noise, noisy_latents, timesteps, huber_c
return noise, noisy_latents, timesteps
def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
@@ -6015,6 +6032,8 @@ def conditional_loss(
elif loss_type == "l1":
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
if huber_c is None:
raise NotImplementedError("huber_c not implemented correctly")
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
@@ -6022,6 +6041,8 @@ def conditional_loss(
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == "smooth_l1":
if huber_c is None:
raise NotImplementedError("huber_c not implemented correctly")
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":