Fix loss recorder on 0. Fix validation for cached runs. Assert on validation dataset

This commit is contained in:
rockerBOO
2025-01-23 09:57:24 -05:00
parent b489082495
commit c04e5dfe92
8 changed files with 46 additions and 20 deletions

View File

@@ -1,5 +1,5 @@
import argparse
from typing import List, Optional
from typing import List, Optional, Union
import torch
from accelerate import Accelerator
@@ -23,8 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:
@@ -37,6 +37,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
(