mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Add CDC-FM parameters to model metadata
- Add ss_use_cdc_fm, ss_cdc_k_neighbors, ss_cdc_k_bandwidth, ss_cdc_d_cdc, ss_cdc_gamma - Ensures CDC-FM training parameters are tracked in model metadata - Enables reproducibility and model provenance tracking
This commit is contained in:
@@ -461,6 +461,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||
|
||||
# CDC-FM metadata
|
||||
metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False)
|
||||
metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None)
|
||||
metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None)
|
||||
metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None)
|
||||
metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None)
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||
|
||||
|
||||
@@ -652,7 +652,7 @@ class NetworkTrainer:
|
||||
if val_dataset_group is not None:
|
||||
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
|
||||
|
||||
if unet is none:
|
||||
if unet is None:
|
||||
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
|
||||
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
|
||||
|
||||
@@ -661,10 +661,10 @@ class NetworkTrainer:
|
||||
accelerator.print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_weights is not none:
|
||||
if args.base_weights is not None:
|
||||
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
||||
for i, weight_path in enumerate(args.base_weights):
|
||||
if args.base_weights_multiplier is none or len(args.base_weights_multiplier) <= i:
|
||||
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
||||
multiplier = 1.0
|
||||
else:
|
||||
multiplier = args.base_weights_multiplier[i]
|
||||
|
||||
Reference in New Issue
Block a user