fix wandb val logging
(mirror of https://github.com/kohya-ss/sd-scripts.git)

library/train_util.py
@@ -13,17 +13,7 @@ import re
 import shutil
 import time
 import typing
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Sequence,
-    Tuple,
-    Union
-)
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import glob
 import math
@@ -146,12 +136,13 @@ IMAGE_TRANSFORMS = transforms.Compose(
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
 
+
 def split_train_val(
     paths: List[str],
     sizes: List[Optional[Tuple[int, int]]],
     is_training_dataset: bool,
     validation_split: float,
-    validation_seed: int | None
+    validation_seed: int | None,
 ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
     """
     Split the dataset into train and validation
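
The body of split_train_val lies outside this hunk. As a point of reference, a minimal sketch consistent with the signature above, assuming validation_seed drives a deterministic shuffle so the training and validation dataset instances agree on the split (an illustration, not the repository's exact code):

    import random
    from typing import List, Optional, Tuple

    def split_train_val(
        paths: List[str],
        sizes: List[Optional[Tuple[int, int]]],
        is_training_dataset: bool,
        validation_split: float,
        validation_seed: Optional[int],
    ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
        # Shuffle deterministically so both dataset instances see the same order.
        if validation_seed is not None:
            order = list(range(len(paths)))
            random.Random(validation_seed).shuffle(order)
            paths = [paths[i] for i in order]
            sizes = [sizes[i] for i in order]
        # The first (1 - validation_split) fraction trains; the remainder validates.
        split_index = int(len(paths) * (1 - validation_split))
        if is_training_dataset:
            return paths[:split_index], sizes[:split_index]
        return paths[split_index:], sizes[split_index:]
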
@@ -1999,11 +1990,7 @@ class DreamBoothDataset(BaseDataset):
             # required for training images dataset of regularization images
             else:
                 img_paths, sizes = split_train_val(
-                    img_paths,
-                    sizes,
-                    self.is_training_dataset,
-                    self.validation_split,
-                    self.validation_seed
+                    img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
                 )
 
         logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
@@ -5952,7 +5939,9 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
     return timesteps
 
 
-def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
+def get_noise_noisy_latents_and_timesteps(
+    args, noise_scheduler, latents: torch.FloatTensor
+) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
     # Sample noise that we'll add to the latents
     noise = torch.randn_like(latents, device=latents.device)
     if args.noise_offset:
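
This hunk only reflows the signature. For context, the `if args.noise_offset:` branch that follows implements the common "noise offset" trick; a hedged sketch of that technique (not necessarily the repository's exact implementation):

    import torch

    def apply_noise_offset(noise: torch.Tensor, latents: torch.Tensor, noise_offset: float) -> torch.Tensor:
        # Shift the mean of the noise independently per sample and per channel,
        # which helps diffusion models learn very dark or very bright images.
        mean_shift = torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
        return noise + noise_offset * mean_shift
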
@@ -6462,12 +6451,16 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
 
     if "wandb" in [tracker.name for tracker in accelerator.trackers]:
         import wandb
 
         wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
 
         # Define specific metrics to handle validation and epochs "steps"
         wandb_tracker.define_metric("epoch", hidden=True)
         wandb_tracker.define_metric("val_step", hidden=True)
 
+        wandb_tracker.define_metric("global_step", hidden=True)
+
+
 # endregion
 
+
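
wandb's `Run.define_metric(name, hidden=True)` registers a metric that is logged but excluded from the auto-generated charts, which makes `epoch`, `val_step`, and now `global_step` usable purely as x-axis values for other series. A standalone sketch of the same idea in plain wandb (the project name and the `step_metric` line are illustrative, not part of this commit):

    import wandb

    run = wandb.init(project="sd-scripts-demo")  # hypothetical project name
    # Logged but hidden from auto-generated plots:
    run.define_metric("global_step", hidden=True)
    run.define_metric("val_step", hidden=True)
    # A series can then adopt one of the hidden metrics as its x-axis:
    run.define_metric("loss/validation/*", step_metric="val_step")
    run.log({"loss/validation/step_average": 0.21, "val_step": 3})
    run.finish()
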
train_network.py
@@ -119,6 +119,45 @@ class NetworkTrainer:
 
         return logs
 
+    def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
+        self.accelerator_logging(accelerator, logs, global_step, global_step, epoch)
+
+    def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
+        self.accelerator_logging(accelerator, logs, epoch, global_step, epoch)
+
+    def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int):
+        self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step)
+
+    def accelerator_logging(
+        self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None
+    ):
+        """
+        step_value is for tensorboard, other values are for wandb
+        """
+        tensorboard_tracker = None
+        wandb_tracker = None
+        other_trackers = []
+        for tracker in accelerator.trackers:
+            if tracker.name == "tensorboard":
+                tensorboard_tracker = accelerator.get_tracker("tensorboard")
+            elif tracker.name == "wandb":
+                wandb_tracker = accelerator.get_tracker("wandb")
+            else:
+                other_trackers.append(accelerator.get_tracker(tracker.name))
+
+        if tensorboard_tracker is not None:
+            tensorboard_tracker.log(logs, step=step_value)
+
+        if wandb_tracker is not None:
+            logs["global_step"] = global_step
+            logs["epoch"] = epoch
+            if val_step is not None:
+                logs["val_step"] = val_step
+            wandb_tracker.log(logs)
+
+        for tracker in other_trackers:
+            tracker.log(logs, step=step_value)
+
     def assert_extra_args(
         self,
         args,
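
The reason accelerator_logging splits the trackers apart: tensorboard accepts an arbitrary step per write, but wandb requires the step passed to log() to be monotonically non-decreasing, and epoch- and validation-scale counters run far behind global_step. So wandb receives no step= argument at all; global_step, epoch, and val_step travel inside the payload as the hidden metrics defined in init_trackers. A runnable sketch of how the three wrappers map onto that dispatch (illustration only, with made-up values):

    from typing import Optional

    def describe_logging(step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None) -> None:
        # Mirrors accelerator_logging's dispatch, for illustration only.
        print(f"tensorboard: log(logs, step={step_value})")
        extras = {"global_step": global_step, "epoch": epoch}
        if val_step is not None:
            extras["val_step"] = val_step
        print(f"wandb: log(logs merged with {extras})  # no step= argument")

    # How the three wrappers above call it, e.g. at global_step=100, epoch=1:
    describe_logging(step_value=100, global_step=100, epoch=1)              # step_logging
    describe_logging(step_value=1, global_step=100, epoch=1)                # epoch_logging
    describe_logging(step_value=103, global_step=100, epoch=1, val_step=3)  # val_logging
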
@@ -1412,7 +1451,7 @@ class NetworkTrainer:
                     logs = self.generate_step_logs(
                         args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
                     )
-                    accelerator.log(logs, step=global_step)
+                    self.step_logging(accelerator, logs, global_step, epoch + 1)
 
                 # VALIDATION PER STEP: global_step is already incremented
                 # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
@@ -1428,7 +1467,7 @@ class NetworkTrainer:
                         disable=not accelerator.is_local_main_process,
                         desc="validation steps",
                     )
-                    val_ts_step = 0
+                    val_timesteps_step = 0
                     for val_step, batch in enumerate(val_dataloader):
                         if val_step >= validation_steps:
                             break
@@ -1457,20 +1496,18 @@ class NetworkTrainer:
                         )
 
                         current_loss = loss.detach().item()
-                        val_step_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
+                        val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
                         val_progress_bar.update(1)
                         val_progress_bar.set_postfix(
                             {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
                         )
 
-                        if is_tracking:
-                            logs = {"loss/validation/step_current": current_loss}
-                            accelerator.log(
-                                logs, step=global_step + val_ts_step
-                            )  # a bit weird to log with global_step + val_ts_step
+                        # if is_tracking:
+                        #     logs = {f"loss/validation/step_current_{timestep}": current_loss}
+                        #     self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
 
                         self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
-                        val_ts_step += 1
+                        val_timesteps_step += 1
 
                         if is_tracking:
                             loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
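
The commented-out block above is likely the bug behind this commit's title: logging validation loss at step=global_step + val_ts_step (which the old inline comment itself called "a bit weird") advances wandb's internal step past global_step, so the next training log at step=global_step arrives out of order and wandb drops it, since wandb only accepts monotonically non-decreasing steps. A minimal sketch of that failure mode (hypothetical values):

    import wandb

    run = wandb.init(project="sd-scripts-demo")  # hypothetical project name
    run.log({"loss/validation/step_current": 0.20}, step=103)  # global_step + val_ts_step
    run.log({"loss/current": 0.12}, step=100)  # earlier step: wandb warns and ignores this row
    run.finish()
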
@@ -1478,7 +1515,7 @@ class NetworkTrainer:
                             "loss/validation/step_average": val_step_loss_recorder.moving_average,
                             "loss/validation/step_divergence": loss_validation_divergence,
                         }
-                        accelerator.log(logs, step=global_step)
+                        self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
 
                     restore_rng_state(rng_states)
                     args.min_timestep = original_args_min_timestep
@@ -1507,7 +1544,7 @@ class NetworkTrainer:
                     desc="epoch validation steps",
                 )
 
-                val_ts_step = 0
+                val_timesteps_step = 0
                 for val_step, batch in enumerate(val_dataloader):
                     if val_step >= validation_steps:
                         break
@@ -1537,18 +1574,18 @@ class NetworkTrainer:
                     )
 
                     current_loss = loss.detach().item()
-                    val_epoch_loss_recorder.add(epoch=epoch, step=val_ts_step, loss=current_loss)
+                    val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
                     val_progress_bar.update(1)
                     val_progress_bar.set_postfix(
                         {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
                     )
 
-                    if is_tracking:
-                        logs = {"loss/validation/epoch_current": current_loss}
-                        accelerator.log(logs, step=global_step + val_ts_step)
+                    # if is_tracking:
+                    #     logs = {f"loss/validation/epoch_current_{timestep}": current_loss}
+                    #     self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
 
                     self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
-                    val_ts_step += 1
+                    val_timesteps_step += 1
 
                 if is_tracking:
                     avr_loss: float = val_epoch_loss_recorder.moving_average
@@ -1557,7 +1594,7 @@ class NetworkTrainer:
                         "loss/validation/epoch_average": avr_loss,
                         "loss/validation/epoch_divergence": loss_validation_divergence,
                     }
-                    accelerator.log(logs, step=epoch + 1)
+                    self.epoch_logging(accelerator, logs, global_step, epoch + 1)
 
                 restore_rng_state(rng_states)
                 args.min_timestep = original_args_min_timestep
@@ -1569,7 +1606,7 @@ class NetworkTrainer:
             # END OF EPOCH
             if is_tracking:
                 logs = {"loss/epoch_average": loss_recorder.moving_average}
-                accelerator.log(logs, step=epoch + 1)
+                self.epoch_logging(accelerator, logs, global_step, epoch + 1)
 
             accelerator.wait_for_everyone()
 