mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix loss recorder on 0. Fix validation for cached runs. Assert on validation dataset
This commit is contained in:
@@ -3,7 +3,7 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import typing
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Union, Optional
|
||||
import sys
|
||||
import random
|
||||
import time
|
||||
@@ -124,8 +124,10 @@ class NetworkTrainer:
|
||||
|
||||
return logs
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
@@ -512,7 +514,7 @@ class NetworkTrainer:
|
||||
val_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group) # may change some args
|
||||
self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("preparing accelerator")
|
||||
@@ -1414,7 +1416,9 @@ class NetworkTrainer:
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
@@ -1474,7 +1478,9 @@ class NetworkTrainer:
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
|
||||
Reference in New Issue
Block a user