Sync from v0.13
This commit is contained in:
189
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
Normal file
189
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
    """Swizzle an mxfp4 MoE weight tensor and its MX scales into the layout
    expected by the OAI triton mxfp4 matmul kernel.

    Args:
        quant_tensor: packed FP4 weight tensor; the quantization axis is on
            dim -2 on entry and is moved to dim -1 before layout conversion.
        scale: per-block MX scale tensor matching ``quant_tensor``.
        num_warps: warp count used to select the scale layout.

    Returns:
        Tuple ``(weight, InFlexData(), scale)`` where ``weight`` and ``scale``
        are layout-converted triton-kernels tensor wrappers.
    """
    assert has_triton_kernels()
    import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
    from triton_kernels.numerics import InFlexData
    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
    from triton_kernels.tensor_details import layout
    from triton_kernels.tensor_details.layout import StridedLayout

    value_layout_opts: dict[str, Any] = {}
    scale_layout_opts: dict[str, Any] = {}

    if (
        current_platform.is_cuda()
        and current_platform.is_device_capability(90)
        and not is_torch_equal_or_newer("2.8.1")
    ):
        # Hopper on an old torch cannot use the swizzled layouts; fall back
        # to plain strided layouts at a performance cost.
        logger.warning_once(
            "Mxfp4 on hopper is running on torch < 2.8.1, "
            "this causes swizzling to be disabled, which may "
            "cause performance degradation. Please upgrade to torch nightly"
        )
        value_layout = StridedLayout
        scale_layout = StridedLayout
    elif current_platform.is_rocm():
        from vllm.platforms.rocm import on_gfx950

        value_layout = StridedLayout
        if on_gfx950():
            # gfx950 has a dedicated MX scale layout.
            from triton_kernels.tensor_details.layout import GFX950MXScaleLayout

            scale_layout = GFX950MXScaleLayout
        else:
            scale_layout = StridedLayout
    else:
        value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
            mx_axis=1
        )
        scale_layout, scale_layout_opts = (
            layout.make_default_matmul_mxfp4_w_scale_layout(
                mx_axis=1, num_warps=num_warps
            )
        )
    if current_platform.is_cuda():
        # Constrain the kernel autotuner per device generation.
        if current_platform.is_device_capability(90):
            constraints = {
                "split_k": 1,
            }
            opt_flags.update_opt_flags_constraints(constraints)
        elif current_platform.is_device_capability_family(100):
            constraints = {
                "is_persistent": True,
                "epilogue_subtile": 1,
            }
            opt_flags.update_opt_flags_constraints(constraints)
    # transpose the tensor so that the quantization axis is on dim1
    quant_tensor = quant_tensor.transpose(-2, -1)
    scale = scale.transpose(-2, -1)
    quant_tensor = convert_layout(
        wrap_torch_tensor(quant_tensor, dtype=FP4), value_layout, **value_layout_opts
    )
    scale = convert_layout(wrap_torch_tensor(scale), scale_layout, **scale_layout_opts)
    return quant_tensor, InFlexData(), scale
||||
|
||||
|
||||
def _can_support_mxfp4(
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
scoring_func: str = "softmax",
|
||||
activation: str = "swigluoai",
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
):
|
||||
return not (
|
||||
use_grouped_topk
|
||||
or topk_group
|
||||
or num_expert_group
|
||||
or custom_routing_function
|
||||
or e_score_correction_bias
|
||||
or apply_router_weight_on_input
|
||||
or scoring_func != "softmax"
|
||||
or activation != "swigluoai"
|
||||
or expert_load_view
|
||||
or logical_to_physical_map
|
||||
or logical_replica_count
|
||||
)
|
||||
|
||||
|
||||
def get_padding_alignment():
    """Alignment (in elements) to use when padding mxfp4 weights.

    gfx950 targets require 256-element alignment; all other targets use 128.
    """
    arch = triton.runtime.driver.active.get_current_target().arch
    if arch in ("gfx950",):
        return 256
    return 128
||||
|
||||
|
||||
def _dequant_mxfp4(
|
||||
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
try:
|
||||
from quark.torch.kernel import mx
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"The package `amd-quark` is required to use "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`."
|
||||
) from err
|
||||
|
||||
return mx.dq_mxfp4(x, scale, float_dtype)
|
||||
|
||||
|
||||
def _dequant_mxfp4_fake(
|
||||
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(*x.shape[:-1], x.shape[-1] * 2), dtype=float_dtype, device=x.device
|
||||
)
|
||||
|
||||
|
||||
def _quant_dequant_mxfp4(
|
||||
x: torch.Tensor, scale_calculation_mode: str = "even"
|
||||
) -> torch.Tensor:
|
||||
try:
|
||||
from quark.torch.kernel import mx
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"The package `amd-quark` is required to use "
|
||||
"MX-FP4 models. Please install it with `pip install "
|
||||
"amd-quark`."
|
||||
) from err
|
||||
|
||||
return mx.qdq_mxfp4(x, scale_calculation_mode)
|
||||
|
||||
|
||||
def _quant_dequant_mxfp4_fake(
|
||||
x: torch.Tensor, scale_calculation_mode: str = "even"
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
# Protect these operations into a torch custom op to avoid errors as
# torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
# Explanation: Dynamo does not know how to trace the builtin
# `kernel_ext.PyCapsule.dq_uint8_mxfp4_to_half.` This function is either a
# Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python
# extension (perhaps created with pybind).
# TODO: Make sure there is no way to avoid having these functions
# marked as skipped by dynamo.
# NOTE: the former `try/except AttributeError: raise error` wrappers around
# these registrations were no-ops (they only re-raised the same exception),
# so they have been removed; any AttributeError still propagates unchanged.
direct_register_custom_op(
    op_name="dequant_mxfp4",
    op_func=_dequant_mxfp4,
    fake_impl=_dequant_mxfp4_fake,
)
# Public handle to the registered op.
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4

direct_register_custom_op(
    op_name="quant_dequant_mxfp4",
    op_func=_quant_dequant_mxfp4,
    fake_impl=_quant_dequant_mxfp4_fake,
)
# Public handle to the registered op.
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
||||
Reference in New Issue
Block a user