Files
2026-01-19 10:38:50 +08:00

25 lines
754 B
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
# Module-level logger, created by vLLM's `init_logger` with this module's name.
logger = init_logger(__name__)
def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MX-FP8 using FlashInfer's ``mxfp8_quantize``.

    The kernel is invoked with ``is_sf_swizzled_layout=False``. If it hands
    back the scale factors as a 1-D tensor, they are viewed as a 2-D tensor
    whose first dimension is ``x.size(0)``.

    Args:
        x: Tensor to quantize.
            NOTE(review): the ``view(x.size(0), -1)`` below suggests ``x`` is
            expected to be 2-D (rows, hidden) -- confirm against callers.

    Returns:
        A ``(x_q, x_scales)`` tuple: the quantized tensor and its scale
        factors, exactly as produced by FlashInfer apart from the optional
        reshape of ``x_scales`` described above.

    Raises:
        ImportError: If the ``flashinfer`` package cannot be imported.
    """
    try:
        # Imported at call time so that a missing FlashInfer only fails here,
        # with an actionable message. The import keeps its original name to
        # avoid shadowing this function's own name inside its body.
        from flashinfer import mxfp8_quantize
    except ImportError as err:
        raise ImportError(
            "The package `flashinfer` is required to do "
            "MX-FP8 quantization. Please install it with "
            "`pip install flashinfer`"
        ) from err

    x_q, x_scales = mxfp8_quantize(x, is_sf_swizzled_layout=False)
    # Normalise flat scale factors to one row of scales per row of `x`.
    if x_scales.ndim == 1:
        x_scales = x_scales.view(x.size(0), -1)
    return x_q, x_scales