mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Set is_swapping_blocks before loading_device; add a warning that FP8 quantization (fp8_scaled) is skipped for weights that are already FP8
This commit is contained in:
@@ -99,6 +99,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
|
||||
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
||||
loading_dtype = None if args.fp8_base else weight_dtype
|
||||
|
||||
@@ -125,8 +127,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# if args.split_mode:
|
||||
# model = self.prepare_split_model(model, weight_dtype, accelerator)
|
||||
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
if self.is_swapping_blocks:
|
||||
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
|
||||
@@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization(
|
||||
state_dict[key] = value
|
||||
continue
|
||||
|
||||
original_dtype = value.dtype
|
||||
|
||||
if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
|
||||
logger.warning(
|
||||
f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). "
|
||||
"Loading checkpoint as-is without re-quantization."
|
||||
)
|
||||
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
|
||||
value = value.to(target_device)
|
||||
state_dict[key] = value
|
||||
continue
|
||||
|
||||
# Move to calculation device
|
||||
if calc_device is not None:
|
||||
value = value.to(calc_device)
|
||||
|
||||
original_dtype = value.dtype
|
||||
quantized_weight, scale_tensor = quantize_weight(
|
||||
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user