143 lines
4.4 KiB
Python
143 lines
4.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import torch
|
|
|
|
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
|
|
|
|
def _quant_dequant_mxfp6(
|
|
x: torch.Tensor,
|
|
quant_dtype: str,
|
|
scale_calculation_mode: str = "even",
|
|
) -> torch.Tensor:
|
|
try:
|
|
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
|
|
fake_quantize_fp4_fp6_per_group_with_scale,
|
|
)
|
|
from quark.torch.quantization.utils import even_round, reshape_to_blocks
|
|
except ImportError as err:
|
|
raise ImportError(
|
|
"The package `amd-quark` is required to use "
|
|
"MX-FP6 models. Please install it with `pip install "
|
|
"amd-quark`."
|
|
) from err
|
|
|
|
axis = -1
|
|
block_x = reshape_to_blocks(x, OCP_MX_BLOCK_SIZE, axis)
|
|
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
|
|
amax = amax.squeeze(-1)
|
|
|
|
# TODO: there are other rounding strategies supported in quark and in the
|
|
# config.json that we do not check for here!
|
|
if scale_calculation_mode != "even":
|
|
raise NotImplementedError(
|
|
f"Scale calculation mode {scale_calculation_mode} is not yet "
|
|
"supported in MX-FP6 quantization"
|
|
)
|
|
scale = even_round(amax, quant_dtype)
|
|
|
|
# Apply dequantize(quantize(x)).
|
|
x = fake_quantize_fp4_fp6_per_group_with_scale(
|
|
x,
|
|
scale.to(x.device),
|
|
axis=axis,
|
|
group_size=OCP_MX_BLOCK_SIZE,
|
|
quant_dtype=quant_dtype,
|
|
)
|
|
|
|
return x
|
|
|
|
|
|
def _quant_dequant_mxfp6_fake(
|
|
x: torch.Tensor,
|
|
quant_dtype: str,
|
|
scale_calculation_mode: str = "even",
|
|
) -> torch.Tensor:
|
|
return torch.empty_like(x)
|
|
|
|
|
|
def _dequant_mxfp6(
|
|
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str
|
|
) -> torch.Tensor:
|
|
try:
|
|
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
|
|
dequantize_fp4_fp6_per_group,
|
|
)
|
|
from quark.torch.utils.pack import create_pack_method
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"The package `amd-quark` is required to use "
|
|
"MX-FP6 models. Please install it with `pip install "
|
|
"amd-quark`."
|
|
) from e
|
|
|
|
pack_method = create_pack_method(None, dtype=quant_dtype)
|
|
unpacked_x = pack_method.unpack(x, reorder=False)
|
|
|
|
scale = 2 ** (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype)
|
|
|
|
# TODO: `dequantize_fp4_fp6_per_group` and `prepare_inputs_per_group`
|
|
# always return fp32.
|
|
return dequantize_fp4_fp6_per_group(
|
|
unpacked_x,
|
|
scale,
|
|
axis=-1,
|
|
group_size=OCP_MX_BLOCK_SIZE,
|
|
quant_dtype=quant_dtype,
|
|
).to(float_dtype)
|
|
|
|
|
|
def _dequant_mxfp6_fake(
|
|
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str
|
|
) -> torch.Tensor:
|
|
assert (x.shape[-1] * 4) % 3 == 0
|
|
return torch.empty(
|
|
(*x.shape[:-1], (x.shape[-1] * 4) // 3), dtype=float_dtype, device=x.device
|
|
)
|
|
|
|
|
|
# Protect these operations into a torch custom op to avoid errors as
|
|
# torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
|
|
# Explanation: Dynamo does not know how to trace the builtin
|
|
# `kernel_ext.PyCapsule.dq_uint8_mxfp4_to_half.` This function is either a
|
|
# Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python
|
|
# extension (perhaps created with pybind).
|
|
# TODO: Make sure there is no way to avoid having these functions
|
|
# marked as skipped by dynamo.
|
|
# Register the quant-dequant emulation as a torch custom op (a bare
# `except AttributeError as error: raise error` wrapper would be a no-op,
# so the registration is invoked directly).
direct_register_custom_op(
    op_name="quant_dequant_mxfp6",
    op_func=_quant_dequant_mxfp6,
    mutates_args=[],
    fake_impl=_quant_dequant_mxfp6_fake,
)
|
|
|
|
|
|
# Expose keyword arguments.
|
|
def quant_dequant_mxfp6(
    x: torch.Tensor,
    quant_dtype: str,
    scale_calculation_mode: str = "even",
) -> torch.Tensor:
    """Keyword-friendly wrapper around the registered custom op."""
    op = torch.ops.vllm.quant_dequant_mxfp6
    return op(x, quant_dtype, scale_calculation_mode)
|
|
|
|
|
|
# Register the dequantization kernel as a torch custom op (a bare
# `except AttributeError as error: raise error` wrapper would be a no-op,
# so the registration is invoked directly).
direct_register_custom_op(
    op_name="dequant_mxfp6",
    op_func=_dequant_mxfp6,
    mutates_args=[],
    fake_impl=_dequant_mxfp6_fake,
)
|
|
|
|
|
|
def dequant_mxfp6(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype, quant_dtype: str
) -> torch.Tensor:
    """Wrapper around the registered ``dequant_mxfp6`` custom op."""
    op = torch.ops.vllm.dequant_mxfp6
    return op(x, scale, float_dtype, quant_dtype)
|