add arbitrary dataset feature to each script

This commit is contained in:
Kohya S
2023-06-15 20:39:39 +09:00
parent f2989b36c2
commit 9806b00f74
6 changed files with 115 additions and 92 deletions

View File

@@ -20,7 +20,13 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction
from library.custom_train_functions import (
apply_snr_weight,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
imagenet_templates_small = [
@@ -88,6 +94,9 @@ def train(args):
print(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
assert (
args.dataset_class is None
), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"
cache_latents = args.cache_latents