# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to MX-FP8 (E4M3) via FlashInfer's ``mxfp8_quantize``.

    The FlashInfer kernel is called with ``is_sf_swizzled_layout=False``, so
    the scale factors are requested in the non-swizzled layout.

    Args:
        x: Tensor to quantize. NOTE(review): the scale reshape below uses
            ``x.size(0)`` as the row count, which presumes ``x`` is 2-D
            (rows, cols) — confirm against callers.

    Returns:
        ``(x_q, x_scales)`` as produced by FlashInfer. If FlashInfer returns
        the scales as a 1-D tensor, they are viewed as ``(x.size(0), -1)``.

    Raises:
        ImportError: if ``flashinfer`` cannot be imported.
    """
    # Imported lazily so that this module can be imported without FlashInfer
    # installed; the dependency is only required when quantization runs.
    # Aliased to a private name so it does not shadow this function's own name.
    try:
        from flashinfer import mxfp8_quantize as _flashinfer_mxfp8_quantize
    except ImportError as err:
        # NOTE(review): FlashInfer may be published on PyPI as
        # `flashinfer-python`; confirm the install hint below.
        raise ImportError(
            "The package `flashinfer` is required to do "
            "MX-FP8 quantization. Please install it with "
            "`pip install flashinfer`"
        ) from err

    x_q, x_scales = _flashinfer_mxfp8_quantize(x, is_sf_swizzled_layout=False)
    # Give flattened scales one row per leading row of ``x`` so callers get a
    # 2-D scale tensor regardless of how FlashInfer returned it.
    if x_scales.ndim == 1:
        x_scales = x_scales.view(x.size(0), -1)
    return x_q, x_scales