mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix wandb val logging
This commit is contained in:
@@ -13,17 +13,7 @@ import re
|
||||
import shutil
|
||||
import time
|
||||
import typing
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
||||
import glob
|
||||
import math
|
||||
@@ -146,12 +136,13 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
|
||||
def split_train_val(
|
||||
paths: List[str],
|
||||
paths: List[str],
|
||||
sizes: List[Optional[Tuple[int, int]]],
|
||||
is_training_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: int | None
|
||||
is_training_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: int | None,
|
||||
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
|
||||
"""
|
||||
Split the dataset into train and validation
|
||||
@@ -1842,7 +1833,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
class DreamBoothDataset(BaseDataset):
|
||||
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
|
||||
|
||||
# The is_training_dataset defines the type of dataset, training or validation
|
||||
# The is_training_dataset defines the type of dataset, training or validation
|
||||
# if is_training_dataset is True -> training dataset
|
||||
# if is_training_dataset is False -> validation dataset
|
||||
def __init__(
|
||||
@@ -1981,29 +1972,25 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
# We want to create a training and validation split. This should be improved in the future
|
||||
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||
# short-term solution to limit what is necessary to implement validation datasets
|
||||
#
|
||||
#
|
||||
# We split the dataset for the subset based on if we are doing a validation split
|
||||
# The self.is_training_dataset defines the type of dataset, training or validation
|
||||
# The self.is_training_dataset defines the type of dataset, training or validation
|
||||
# if self.is_training_dataset is True -> training dataset
|
||||
# if self.is_training_dataset is False -> validation dataset
|
||||
if self.validation_split > 0.0:
|
||||
# For regularization images we do not want to split this dataset.
|
||||
# For regularization images we do not want to split this dataset.
|
||||
if subset.is_reg is True:
|
||||
# Skip any validation dataset for regularization images
|
||||
if self.is_training_dataset is False:
|
||||
img_paths = []
|
||||
sizes = []
|
||||
# Otherwise the img_paths remain as original img_paths and no split
|
||||
# Otherwise the img_paths remain as original img_paths and no split
|
||||
# required for training images dataset of regularization images
|
||||
else:
|
||||
img_paths, sizes = split_train_val(
|
||||
img_paths,
|
||||
sizes,
|
||||
self.is_training_dataset,
|
||||
self.validation_split,
|
||||
self.validation_seed
|
||||
img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
|
||||
)
|
||||
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
@@ -2373,7 +2360,7 @@ class ControlNetDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
validation_seed: Optional[int],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
@@ -2431,9 +2418,9 @@ class ControlNetDataset(BaseDataset):
|
||||
self.image_data = self.dreambooth_dataset_delegate.image_data
|
||||
self.batch_size = batch_size
|
||||
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_seed = validation_seed
|
||||
|
||||
# assert all conditioning data exists
|
||||
missing_imgs = []
|
||||
@@ -5952,7 +5939,9 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
|
||||
return timesteps
|
||||
|
||||
|
||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||
def get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.FloatTensor
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
@@ -6444,7 +6433,7 @@ def sample_image_inference(
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
|
||||
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
|
||||
"""
|
||||
Initialize experiment trackers with tracker specific behaviors
|
||||
"""
|
||||
@@ -6461,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
|
||||
)
|
||||
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
import wandb
|
||||
import wandb
|
||||
|
||||
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
|
||||
|
||||
# Define specific metrics to handle validation and epochs "steps"
|
||||
wandb_tracker.define_metric("epoch", hidden=True)
|
||||
wandb_tracker.define_metric("val_step", hidden=True)
|
||||
|
||||
wandb_tracker.define_metric("global_step", hidden=True)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user