Set is_swapping_blocks before loading_device, add warning for ignoring fp8_scaled if already fp8

This commit is contained in:
rockerBOO
2025-10-10 15:58:21 -04:00
parent 5e366acda4
commit f9710863ca
2 changed files with 14 additions and 3 deletions

View File

@@ -99,6 +99,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
@@ -125,8 +127,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# if args.split_mode:
# model = self.prepare_split_model(model, weight_dtype, accelerator)
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")

View File

@@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization(
state_dict[key] = value
continue
original_dtype = value.dtype
if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
logger.warning(
f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). "
"Loading checkpoint as-is without re-quantization."
)
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
value = value.to(target_device)
state_dict[key] = value
continue
# Move to calculation device
if calc_device is not None:
value = value.to(calc_device)
original_dtype = value.dtype
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)