Sync from v0.13
This commit is contained in:
24
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
Normal file
24
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor to MX-FP8 (E4M3) format via flashinfer.

    Args:
        x: Input tensor to quantize. Its first dimension is used as the row
            count when reshaping the returned scale factors.

    Returns:
        A tuple ``(x_q, x_scales)`` where ``x_q`` is the quantized tensor and
        ``x_scales`` holds the per-block scale factors, viewed as
        ``(x.size(0), -1)`` if flashinfer returned them as a flat 1-D tensor.

    Raises:
        ImportError: If the optional ``flashinfer`` package is not installed.
    """
    # Imported lazily so this module can be imported without flashinfer;
    # the alias avoids shadowing this function's own name.
    try:
        from flashinfer import mxfp8_quantize as _flashinfer_mxfp8_quantize
    except ImportError as err:
        # Fix: the original message was missing a space before the pip
        # command, rendering as "...install it with`pip install flashinfer`".
        raise ImportError(
            "The package `flashinfer` is required to do "
            "MX-FP8 quantization. Please install it with "
            "`pip install flashinfer`"
        ) from err

    # is_sf_swizzled_layout=False requests the non-swizzled scale layout;
    # presumably what the downstream consumer expects — TODO confirm.
    x_q, x_scales = _flashinfer_mxfp8_quantize(x, is_sf_swizzled_layout=False)
    if x_scales.ndim == 1:
        # flashinfer may return the scales flattened; reshape to one row of
        # scales per row of the input.
        x_scales = x_scales.view(x.size(0), -1)
    return x_q, x_scales
|
||||
Reference in New Issue
Block a user