Set is_swapping_blocks before loading_device, add warning for ignoring fp8_scaled if already fp8

This commit is contained in:
rockerBOO
2025-10-10 15:58:21 -04:00
parent 5e366acda4
commit f9710863ca
2 changed files with 14 additions and 3 deletions

View File

@@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization(
state_dict[key] = value
continue
original_dtype = value.dtype
if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
logger.warning(
f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). "
"Loading checkpoint as-is without re-quantization."
)
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
value = value.to(target_device)
state_dict[key] = value
continue
# Move to calculation device
if calc_device is not None:
value = value.to(calc_device)
original_dtype = value.dtype
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)