From f9710863ca0e80d3d781c2f04ef1a23e03d9fd90 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Fri, 10 Oct 2025 15:58:21 -0400
Subject: [PATCH] Set is_swapping_blocks before loading_device, add warning for
 ignoring fp8_scaled if already fp8

---
 flux_train_network.py             |  4 ++--
 library/fp8_optimization_utils.py | 13 ++++++++++++-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/flux_train_network.py b/flux_train_network.py
index cfc61708..8620a6f2 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -99,6 +99,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
     def load_target_model(self, args, weight_dtype, accelerator):
         # currently offload to cpu for some models
+        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
+
         # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
         loading_dtype = None if args.fp8_base else weight_dtype
 
@@ -125,8 +127,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
         # if args.split_mode:
         #     model = self.prepare_split_model(model, weight_dtype, accelerator)
-
-        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
         if self.is_swapping_blocks:
             # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
             logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
diff --git a/library/fp8_optimization_utils.py b/library/fp8_optimization_utils.py
index 02f99ab6..9ea62a58 100644
--- a/library/fp8_optimization_utils.py
+++ b/library/fp8_optimization_utils.py
@@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization(
                 state_dict[key] = value
                 continue
 
+            original_dtype = value.dtype
+
+            if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
+                logger.warning(
+                    f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). "
+                    "Loading checkpoint as-is without re-quantization."
+                )
+                target_device = calc_device if (calc_device is not None and move_to_device) else original_device
+                value = value.to(target_device)
+                state_dict[key] = value
+                continue
+
             # Move to calculation device
             if calc_device is not None:
                 value = value.to(calc_device)
 
-            original_dtype = value.dtype
             quantized_weight, scale_tensor = quantize_weight(
                 key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
             )