init
This commit is contained in:
45
model_executor/layers/quantization/utils/mxfp4_utils.py
Normal file
45
model_executor/layers/quantization/utils/mxfp4_utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
def per_token_group_quant_mxfp4(x: torch.Tensor,
|
||||
block_k: int,
|
||||
scale_calculation_mode: str = "even"
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
try:
|
||||
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
|
||||
fake_quantize_fp4_fp6_per_group_with_scale)
|
||||
from quark.torch.quantization.utils import (even_round,
|
||||
reshape_to_blocks)
|
||||
except ImportError as err:
|
||||
raise ImportError("The package `amd-quark` is required to use "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`.") from err
|
||||
|
||||
axis = -1
|
||||
block_x = reshape_to_blocks(x, block_k, axis)
|
||||
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
|
||||
amax = amax.squeeze(-1)
|
||||
|
||||
# TODO: there are other rounding strategies supported in quark and in the
|
||||
# config.json that we do not check for here!
|
||||
if scale_calculation_mode != "even":
|
||||
raise NotImplementedError(
|
||||
f"Scale calculation mode {scale_calculation_mode} is not yet "
|
||||
"supported in MX-FP4 quantization")
|
||||
scale = even_round(amax, "fp4")
|
||||
|
||||
# Apply dequantize(quantize(x)).
|
||||
x = fake_quantize_fp4_fp6_per_group_with_scale(
|
||||
x,
|
||||
scale.to(x.device),
|
||||
axis=axis,
|
||||
group_size=block_k,
|
||||
quant_dtype="fp4",
|
||||
)
|
||||
|
||||
return x, scale
|
||||
Reference in New Issue
Block a user