Merge branch 'sd3' into lumina

rockerBOO
2025-02-26 22:13:56 -05:00
14 changed files with 684 additions and 813 deletions


@@ -13,17 +13,7 @@ import re
 import shutil
 import time
 import typing
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Sequence,
-    Tuple,
-    Union
-)
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import glob
 import math
@@ -146,12 +136,13 @@ IMAGE_TRANSFORMS = transforms.Compose(
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
 def split_train_val(
-    paths: List[str],
+    paths: List[str],
     sizes: List[Optional[Tuple[int, int]]],
-    is_training_dataset: bool,
-    validation_split: float,
-    validation_seed: int | None
+    is_training_dataset: bool,
+    validation_split: float,
+    validation_seed: int | None,
 ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
     """
     Split the dataset into train and validation
@@ -161,15 +152,19 @@ def split_train_val(
     [0:80] = 80 training images
     [80:] = 20 validation images
     """
+    dataset = list(zip(paths, sizes))
     if validation_seed is not None:
         logging.info(f"Using validation seed: {validation_seed}")
         prevstate = random.getstate()
         random.seed(validation_seed)
-        random.shuffle(paths)
+        random.shuffle(dataset)
         random.setstate(prevstate)
     else:
-        random.shuffle(paths)
+        random.shuffle(dataset)
+    paths, sizes = zip(*dataset)
+    paths = list(paths)
+    sizes = list(sizes)
     # Split the dataset between training and validation
     if is_training_dataset:
         # Training dataset we split to the first part
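Note on this hunk: the fix zips `paths` and `sizes` before shuffling, so each path stays paired with its cached size through the seeded shuffle, and the global RNG state is restored afterwards. A minimal standalone sketch of the idea (the helper name `shuffled_split` and the sample file names are made up for illustration, not part of the codebase):

import random
from typing import List, Optional, Tuple

def shuffled_split(
    paths: List[str],
    sizes: List[Optional[Tuple[int, int]]],
    validation_split: float,
    validation_seed: Optional[int],
) -> Tuple[List[str], List[str]]:
    dataset = list(zip(paths, sizes))  # keep each path paired with its size
    if validation_seed is not None:
        prevstate = random.getstate()  # leave the global RNG stream untouched
        random.seed(validation_seed)
        random.shuffle(dataset)
        random.setstate(prevstate)
    else:
        random.shuffle(dataset)
    split_index = int(len(dataset) * (1 - validation_split))
    train_paths = [p for p, _ in dataset[:split_index]]
    val_paths = [p for p, _ in dataset[split_index:]]
    return train_paths, val_paths

# 100 made-up images with a 0.2 split -> 80 training paths, 20 validation paths
train, val = shuffled_split([f"img_{i:03d}.png" for i in range(100)], [None] * 100, 0.2, 42)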
@@ -1850,7 +1845,7 @@ class BaseDataset(torch.utils.data.Dataset):
 class DreamBoothDataset(BaseDataset):
     IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
-    # The is_training_dataset defines the type of dataset, training or validation
+    # The is_training_dataset defines the type of dataset, training or validation
     # if is_training_dataset is True -> training dataset
     # if is_training_dataset is False -> validation dataset
     def __init__(
@@ -1991,29 +1986,25 @@ class DreamBoothDataset(BaseDataset):
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
# We want to create a training and validation split. This should be improved in the future
# to allow a clearer distinction between training and validation. This can be seen as a
# to allow a clearer distinction between training and validation. This can be seen as a
# short-term solution to limit what is necessary to implement validation datasets
#
#
# We split the dataset for the subset based on if we are doing a validation split
# The self.is_training_dataset defines the type of dataset, training or validation
# The self.is_training_dataset defines the type of dataset, training or validation
# if self.is_training_dataset is True -> training dataset
# if self.is_training_dataset is False -> validation dataset
if self.validation_split > 0.0:
# For regularization images we do not want to split this dataset.
# For regularization images we do not want to split this dataset.
if subset.is_reg is True:
# Skip any validation dataset for regularization images
if self.is_training_dataset is False:
img_paths = []
sizes = []
# Otherwise the img_paths remain as original img_paths and no split
# Otherwise the img_paths remain as original img_paths and no split
# required for training images dataset of regularization images
else:
img_paths, sizes = split_train_val(
img_paths,
sizes,
self.is_training_dataset,
self.validation_split,
self.validation_seed
img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
)
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
@@ -2385,7 +2376,7 @@ class ControlNetDataset(BaseDataset):
         bucket_no_upscale: bool,
         debug_dataset: bool,
         validation_split: float,
-        validation_seed: Optional[int],
+        validation_seed: Optional[int],
     ) -> None:
         super().__init__(resolution, network_multiplier, debug_dataset)
@@ -2443,9 +2434,9 @@ class ControlNetDataset(BaseDataset):
         self.image_data = self.dreambooth_dataset_delegate.image_data
         self.batch_size = batch_size
         self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
-        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
+        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
         self.validation_split = validation_split
-        self.validation_seed = validation_seed
+        self.validation_seed = validation_seed
         # assert all conditioning data exists
         missing_imgs = []
@@ -5958,12 +5949,17 @@ def save_sd_model_on_train_end_common(
 def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
-    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
+    if min_timestep < max_timestep:
+        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
+    else:
+        timesteps = torch.full((b_size,), max_timestep, device="cpu")
     timesteps = timesteps.long().to(device)
     return timesteps
-def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
+def get_noise_noisy_latents_and_timesteps(
+    args, noise_scheduler, latents: torch.FloatTensor
+) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
     # Sample noise that we'll add to the latents
     noise = torch.randn_like(latents, device=latents.device)
     if args.noise_offset:
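Aside on the `get_timesteps` change in this hunk: `torch.randint(low, high, ...)` requires `low < high`, so a configuration where `min_timestep == max_timestep` would raise an error; the guard falls back to `torch.full` with the fixed value. A standalone copy for illustration (the name `pick_timesteps` is hypothetical):

import torch

def pick_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
    # Mirrors the guarded sampling shown in the hunk above
    if min_timestep < max_timestep:
        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
    else:
        timesteps = torch.full((b_size,), max_timestep, device="cpu")
    return timesteps.long().to(device)

print(pick_timesteps(0, 1000, 4, torch.device("cpu")))   # four random timesteps in [0, 1000)
print(pick_timesteps(500, 500, 4, torch.device("cpu")))  # tensor([500, 500, 500, 500])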
@@ -6465,7 +6461,7 @@ def sample_image_inference(
         wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption
-def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
+def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
     """
     Initialize experiment trackers with tracker specific behaviors
     """
@@ -6482,13 +6478,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
     )
     if "wandb" in [tracker.name for tracker in accelerator.trackers]:
-        import wandb
+        import wandb
         wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
         # Define specific metrics to handle validation and epochs "steps"
         wandb_tracker.define_metric("epoch", hidden=True)
         wandb_tracker.define_metric("val_step", hidden=True)
+        wandb_tracker.define_metric("global_step", hidden=True)
 # endregion
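The `define_metric(..., hidden=True)` calls register `epoch`, `val_step`, and `global_step` as counters that are kept out of wandb's automatically generated charts; other metrics can then be bound to them as custom x-axes. An illustrative follow-up sketch (the `loss/validation` name and glob pattern are examples, not taken from the training scripts):

# Assuming init_trackers() above has run and wandb_tracker is the unwrapped wandb run:
wandb_tracker.define_metric("loss/validation*", step_metric="val_step")  # plot these against val_step
accelerator.log({"loss/validation": 0.42, "val_step": 3})                # Accelerate forwards the dict to wandb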