Update train_network.py

This commit is contained in:
DKnight54
2025-05-26 01:46:17 +08:00
committed by GitHub
parent 61870e36df
commit 6f04b129df

View File

@@ -210,13 +210,12 @@ class NetworkTrainer:
current_step = Value("i", 0)
if args.incremental_reg_reload:
train_dataset_group.set_reg_reload(args.incremental_reg_reload)
if args.persistent_data_loader_workers:
logger.warning("persistent_data_loader_workers has been set to False because incremental_reg_reload is enabled.")
args.persistent_data_loader_workers = False
if args.randomized_regularization_image:
# train_dataset_group.set_reg_randomize() triggers a reload to initial state with randomized regularization images. Ensure that this occurs before initial caching to prevent data mismatch
logger.info("Reloading sequentially loaded regularization images to replace with randomly selected regularization images...")
train_dataset_group.set_reg_randomize(args.randomized_regularization_image)
if args.debug_dataset:
@@ -378,9 +377,6 @@ class NetworkTrainer:
# dataloaderを準備する
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
if args.incremental_reg_reload:
logger.warning("incremental_reg_reload = True. Incremental reloading of Regularization Images requires persistent_data_loader_workers = false, overriding.")
args.persistent_data_loader_workers = False
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers