Add CDC-FM parameters to model metadata

- Add ss_use_cdc_fm, ss_cdc_k_neighbors, ss_cdc_k_bandwidth, ss_cdc_d_cdc, ss_cdc_gamma
- Ensures CDC-FM training parameters are tracked in model metadata
- Enables reproducibility and model provenance tracking
This commit is contained in:
rockerBOO
2025-10-09 22:51:47 -04:00
parent 20c6ae5a9a
commit f450443fe4
2 changed files with 10 additions and 3 deletions

View File

@@ -461,6 +461,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
# CDC-FM metadata
metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False)
metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None)
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

View File

@@ -652,7 +652,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
if unet is none:
if unet is None:
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
@@ -661,10 +661,10 @@ class NetworkTrainer:
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not none:
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is none or len(args.base_weights_multiplier) <= i:
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]