Fix loss recorder on 0. Fix validation for cached runs. Assert on validation dataset.

This commit is contained in:
rockerBOO
2025-01-23 09:57:24 -05:00
parent b489082495
commit c04e5dfe92
8 changed files with 46 additions and 20 deletions

View File

@@ -2,7 +2,7 @@ import argparse
import copy
import math
import random
from typing import Any, Optional
from typing import Any, Optional, Union
import torch
from accelerate import Accelerator
@@ -26,7 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
super().__init__()
self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
# super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
@@ -56,9 +56,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
# enumerate resolutions from dataset for positional embeddings
self.resolutions = train_dataset_group.get_resolutions()
resolutions = train_dataset_group.get_resolutions()
if val_dataset_group is not None:
resolutions = resolutions + val_dataset_group.get_resolutions()
self.resolutions = resolutions
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models