Fix loss recorder when the loss is 0. Fix validation for cached runs. Add assertions for the validation dataset.

This commit is contained in:
rockerBOO
2025-01-23 09:57:24 -05:00
parent b489082495
commit c04e5dfe92
8 changed files with 46 additions and 20 deletions

View File

@@ -3,7 +3,7 @@ import argparse
import math
import os
import typing
from typing import Any, List
from typing import Any, List, Union, Optional
import sys
import random
import time
@@ -124,8 +124,10 @@ class NetworkTrainer:
return logs
def assert_extra_args(self, args, train_dataset_group):
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
train_dataset_group.verify_bucket_reso_steps(64)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -512,7 +514,7 @@ class NetworkTrainer:
val_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group) # may change some args
self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args
# acceleratorを準備する
logger.info("preparing accelerator")
@@ -1414,7 +1416,9 @@ class NetworkTrainer:
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False
is_train=False,
train_text_encoder=False,
train_unet=False
)
current_loss = loss.detach().item()
@@ -1474,7 +1478,9 @@ class NetworkTrainer:
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False
is_train=False,
train_text_encoder=False,
train_unet=False
)
current_loss = loss.detach().item()