Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .quark_ocp_mx import QuarkOCP_MX
from .quark_scheme import QuarkScheme
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"]

View File

@@ -0,0 +1,343 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from fractions import Fraction
from functools import cache, partial
from typing import Any
import torch
import torch.nn.functional as F
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4,
quant_dequant_mxfp4,
)
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
dequant_mxfp6,
quant_dequant_mxfp6,
)
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
)
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.platforms import current_platform
from .quark_scheme import QuarkScheme
logger = init_logger(__name__)
# TODO: move registration of custom op to aiter_ops.py
# `from vllm._aiter_ops import rocm_aiter_ops`
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
# for envs checks which does not require @cache anymore.
# triton kernel is torch compile compatible.
# does not require direct registration.
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
    """Whether the AITER FP4 ASM GEMM path is enabled.

    Requires a ROCm platform plus both the ``VLLM_ROCM_USE_AITER`` and
    ``VLLM_ROCM_USE_AITER_FP4_ASM_GEMM`` env flags. Cached because neither
    the platform nor the env flags change during a process lifetime.
    """
    if not current_platform.is_rocm():
        return False
    return envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM and envs.VLLM_ROCM_USE_AITER
try:
    from aiter.ops.shuffle import shuffle_weight
    from aiter.ops.triton.gemm_afp4wfp4 import (
        gemm_afp4wfp4,
        gemm_afp4wfp4_preshuffled_weight_scales,
    )
    from aiter.ops.triton.quant import dynamic_mxfp4_quant
    from vllm.utils.torch_utils import direct_register_custom_op

    if is_rocm_aiter_fp4_asm_gemm_enabled():
        from aiter import gemm_a4w4, per_1x32_f4_quant_hip

    def gemm_with_dynamic_quant(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        rocm_use_aiter_fp4_asm_gemm: bool = False,
        out_dtype: torch.dtype | None = torch.bfloat16,
        x_scales: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """MXFP4 GEMM with on-the-fly activation quantization.

        When ``x_scales`` is None, ``x`` is quantized to MXFP4 here;
        otherwise ``x`` is assumed to be already quantized with the given
        scales. Dispatches to the AITER ASM kernels when
        ``rocm_use_aiter_fp4_asm_gemm`` is set, else to the Triton kernel.
        """
        M = x.shape[0]
        N = weight.shape[0]
        K = weight.shape[1]
        if rocm_use_aiter_fp4_asm_gemm:
            if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
                if x_scales is None:
                    # use hip quant kernel for performance
                    if M >= 32:
                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
                    else:
                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
                else:
                    x_q = x
                    x_s = x_scales
                if M >= 32:
                    x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
                else:
                    x_s = x_s[:M, ...].view(torch.uint8)
                y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
                gemm_afp4wfp4_preshuffled_weight_scales(
                    x_q.view(torch.uint8),
                    weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
                    x_s,
                    weight_scale.view(torch.uint8).view(
                        weight_scale.shape[0] // 32, -1
                    ),
                    out_dtype,
                    y,
                )
            else:
                if x_scales is None:
                    # use hip quant kernel for performance
                    x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
                else:
                    x_q = x
                    x_s = x_scales
                # 32 alignment is enough for dim0 padding of output for
                # gemm_a4w4 kernel
                y = torch.empty(
                    (M + 31) // 32 * 32,
                    weight.shape[0],
                    device=x_q.device,
                    dtype=out_dtype,
                )
                gemm_a4w4(
                    x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
                )
            # Trim any dim0 padding (no-op for the preshuffled branch).
            return y[:M]
        else:
            if x_scales is None:
                x_q, x_s = dynamic_mxfp4_quant(x)
            else:
                x_q = x
                x_s = x_scales
            y = torch.empty(
                x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
            )
            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
            return y

    def gemm_with_dynamic_quant_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        rocm_use_aiter_fp4_asm_gemm: bool = False,
        out_dtype: torch.dtype | None = torch.bfloat16,
        x_scales: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Fake (meta) impl used during torch.compile tracing.
        # FIX: the parameter order must match gemm_with_dynamic_quant — the
        # op schema is derived from the real function and callers pass args
        # positionally, so with the old order (x_scales 4th) the traced
        # out_dtype silently stayed at its bfloat16 default.
        return torch.empty(
            (*x.shape[:-1], weight.shape[0]), dtype=out_dtype, device=x.device
        )

    direct_register_custom_op(
        op_name="gemm_with_dynamic_quant",
        op_func=gemm_with_dynamic_quant,
        mutates_args=[],
        fake_impl=gemm_with_dynamic_quant_fake,
        dispatch_key=current_platform.dispatch_key,
    )
except (ImportError, AttributeError):
    # AITER unavailable: leave sentinels so QuarkOCP_MX can detect the
    # missing kernels and refuse non-emulation mode.
    dynamic_mxfp4_quant = gemm_afp4wfp4 = None
class QuarkOCP_MX(QuarkScheme):
    """Quark scheme for OCP MX microscaling formats (MXFP4 / MXFP6).

    Weights are stored packed in uint8 with one uint8 scale per
    ``OCP_MX_BLOCK_SIZE`` input elements. At runtime the scheme either
    emulates (dequantize the weight, quantize/dequantize the activation,
    then run the linear layer in high precision) or — only for MXFP4
    weights and activations on platforms with native MX support —
    dispatches to the fused ``gemm_with_dynamic_quant`` custom op.
    """

    def __init__(
        self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
    ):
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec
        # Quark spells the dtypes "fp4"/"fp6_e3m2"/"fp6_e2m3"; normalize to
        # the "mxfp..." names used throughout this scheme.
        self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
        self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
        self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
            self.input_dtype, self.weight_dtype
        )
        # Weight packing: FP4 fits two values per byte; FP6 packs four
        # values into three bytes (factor 8/6).
        if self.weight_dtype == "mxfp4":
            self.packed_factor: int | Fraction = 2
            self.dequant_func = dequant_mxfp4
        else:
            self.packed_factor = Fraction(numerator=8, denominator=6)
            self.dequant_func = partial(
                dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
            )
        # Activation QDQ (quantize-then-dequantize) used on the emulation
        # path only.
        if self.input_dtype == "mxfp4":
            self.quant_dequant_func = quant_dequant_mxfp4
        else:
            self.quant_dequant_func = partial(
                quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
            )
        self.static_input_scales = not input_quant_spec.get("is_dynamic")
        if self.static_input_scales:
            raise NotImplementedError(
                "QuarkOCP_MX with static input scales is currently not "
                "implemented. Please open an issue."
            )
        # TODO: integrate (or test) mixed-precision kernel.
        # Emulate unless the platform supports MX natively AND both sides
        # are MXFP4 (the only combination with integrated kernels).
        self.emulate = not current_platform.supports_mx() or (
            self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
        )
        self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
        if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
            # Currently need these kernels if not emulating
            raise NotImplementedError(
                f"{self.__class__.__name__} requires AITER to be installed "
                "for non-emulation mode! Please refer to "
                "https://github.com/ROCm/aiter for installation details."
            )
        if not current_platform.supports_mx():
            logger.warning_once(
                "The current platform does not support native MXFP4/MXFP6 "
                "computation. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision."
            )
        if current_platform.supports_mx() and (
            self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
        ):
            logger.warning_once(
                "The current platform supports native MXFP4/MXFP6 "
                f"computation, but kernels for input_dtype={self.input_dtype} "
                f"and weight_dtype={self.weight_dtype} are not yet integrated "
                "in vLLM. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision."
            )

    def get_packed_dim(self, dim: int, quant_dtype: str) -> int:
        """Return the packed uint8 size of ``dim`` quantized elements."""
        if quant_dtype == "mxfp4":
            assert dim % 2 == 0
            return dim // 2
        elif quant_dtype in {"mxfp6_e3m2", "mxfp6_e2m3"}:
            # FP6 packs 4 * 6 = 24 bits on 3 bytes.
            assert (dim * 3) % 4 == 0
            return (dim * 3) // 4
        else:
            raise NotImplementedError(
                "Unsupported quant_dtype in QuarkOCP_MX.get_packed_dim, "
                f"got quant_dtype={quant_dtype}. Something is wrong, please "
                "open an issue."
            )

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum GPU compute capability: 7.0.
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap loaded tensors; pre-shuffle for the AITER ASM path."""
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
        if self.emulate:
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data, requires_grad=False
            )
        else:
            if self.rocm_use_aiter_fp4_asm_gemm:
                # shuffle weight scale into the blocked/interleaved layout
                # (NOTE(review): presumably matches gemm_a4w4's
                # bpreshuffle=True convention — confirm against AITER docs)
                weight_scale_shuffle = layer.weight_scale.data
                sm, sn = weight_scale_shuffle.shape
                weight_scale_shuffle = weight_scale_shuffle.view(
                    sm // 32, 2, 16, sn // 8, 2, 4, 1
                )
                weight_scale_shuffle = weight_scale_shuffle.permute(
                    0, 3, 5, 2, 4, 1, 6
                ).contiguous()
                weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
                layer.weight_scale = torch.nn.Parameter(
                    weight_scale_shuffle, requires_grad=False
                )
                # shuffle weight
                weight_shuffle = layer.weight.data
                weight_shuffle = shuffle_weight(weight_shuffle, layout=(16, 16))
                layer.weight = torch.nn.Parameter(weight_shuffle, requires_grad=False)
            else:
                # Keep a contiguous transposed copy; the Triton GEMM path
                # transposes it back (weight_scale.T) when invoking the kernel.
                layer.weight_scale = torch.nn.Parameter(
                    layer.weight_scale.data.T.contiguous(), requires_grad=False
                )

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the packed weight and per-group scale parameters.

        ``params_dtype`` is unused here: both the packed values and the
        group scales are stored as uint8.
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        # WEIGHT (packed along the input dimension)
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                self.get_packed_dim(input_size_per_partition, self.weight_dtype),
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.packed_factor,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)
        # WEIGHT SCALE: one scale per OCP_MX_BLOCK_SIZE input elements.
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward: emulated high-precision linear, or fused MXFP4 GEMM.

        NOTE(review): the fused (non-emulate) path does not apply ``bias``;
        presumably callers only reach it with bias=None — confirm.
        """
        if self.emulate:
            dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype)
            qdq_x = self.quant_dequant_func(x)
            return F.linear(qdq_x, dq_w, bias)
        else:
            return torch.ops.vllm.gemm_with_dynamic_quant(
                x,
                layer.weight,
                layer.weight_scale,
                self.rocm_use_aiter_fp4_asm_gemm,
                self.out_dtype,
            )

View File

@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
__all__ = ["QuarkScheme"]
class QuarkScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by Quark.
"""
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
"""
raise NotImplementedError
@abstractmethod
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
"""
raise NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError

View File

@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, cast
import torch
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale,
)
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.platforms import current_platform
__all__ = ["QuarkW8A8Fp8"]
class QuarkW8A8Fp8(QuarkScheme):
    """Quark W8A8 FP8 scheme (per-tensor or per-channel weight scales).

    Activation quantization is static when the input config says so,
    otherwise dynamic — per-token when the input qscheme is per_channel,
    per-tensor otherwise. The GEMM itself is delegated to ``Fp8LinearOp``.
    """

    def __init__(
        self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
    ):
        self.weight_qscheme = cast(str, weight_config.get("qscheme"))
        self.is_static_input_scheme: bool = False
        self.input_qscheme: str | None = None
        if input_config is not None:
            self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
            self.input_qscheme = cast(str, input_config.get("qscheme"))
        # Dynamic per-channel activation quantization maps to a per-token
        # group shape; every other combination is per-tensor.
        per_token = (
            not self.is_static_input_scheme and self.input_qscheme == "per_channel"
        )
        self.act_quant_group_shape = (
            GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
        )
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.is_static_input_scheme,
            act_quant_group_shape=self.act_quant_group_shape,
        )
        self.out_dtype = torch.get_default_dtype()

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer) -> None:
        """Normalize, (re)quantize and transpose loaded weights.

        On fp8-fnuz platforms weights/scales are first converted from
        e4m3fn to e4m3fnuz. The weight is stored transposed for the kernel.
        """
        # If per tensor, when we have a fused module (e.g. QKV) with per
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per tensor
        if self.weight_qscheme == "per_tensor":
            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, "input_scale", None)
                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    input_scale=input_scale,
                )
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale, requires_grad=False)
            else:
                max_w_scale = layer.weight_scale
                weight = layer.weight
            # Collapse the per-partition scales into one max scale.
            max_w_scale, weight = requantize_with_max_scale(
                weight=weight,
                weight_scale=max_w_scale,
                logical_widths=layer.logical_widths,
            )
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
        # If channelwise, scales are already lined up, so just transpose.
        elif self.weight_qscheme == "per_channel":
            weight = layer.weight
            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, "input_scale", None)
                weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=layer.weight_scale,
                    input_scale=input_scale,
                )
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale, requires_grad=False)
            else:
                weight_scale = layer.weight_scale.data
            if self.act_quant_group_shape == GroupShape.PER_TOKEN:
                # Reshape per-channel scales into a column vector.
                weight_scale = weight_scale.view(-1, 1)
            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        else:
            raise ValueError(f"Unknown quantization scheme {self.weight_qscheme}")
        # INPUT SCALE
        if self.is_static_input_scheme:
            # Fused modules may carry several input scales; keep the max.
            layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
        else:
            layer.input_scale = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the fp8 weight plus weight/input scale parameters.

        ``params_dtype`` is unused: the weight is always float8_e4m3fn and
        the scales float32.
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)
        # WEIGHT SCALE
        # TODO: update create_xxx_parameter functions to return
        # the newly added parameters
        if self.weight_qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.weight_qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            # min requirement for fp8 kernels
            weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)
        # INPUT SCALE (only present for static activation quantization)
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` and run the fp8 GEMM via ``Fp8LinearOp``."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )

View File

@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
logger = init_logger(__name__)
class QuarkW8A8Int8(QuarkScheme):
    """Quark W8A8 INT8 scheme backed by a scaled-mm linear kernel.

    Weight scales are per-channel or per-tensor; activation quantization
    can be static or dynamic, symmetric or asymmetric. The concrete kernel
    is chosen per configuration via ``choose_scaled_mm_linear_kernel``.
    """

    # Class-wide set of kernel backend names already logged, so each
    # backend is announced only once per process.
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self,
        qscheme: str,
        is_static_input_scheme: bool | None,
        input_symmetric: bool | None,
    ):
        self.qscheme = qscheme
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

    @classmethod
    def get_min_capability(cls) -> int:
        # turing and up
        return 75

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register int8 weight/scale/zero-point parameters and instantiate
        the scaled-mm kernel for this configuration.

        ``params_dtype`` is unused: weights are int8, scales float32.
        """
        layer.logical_widths = output_partition_sizes
        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
            is_channelwise=(self.qscheme == "per_channel"),
            is_static_input_scheme=(self.is_static_input_scheme is True),
            input_symmetric=(self.input_symmetric is True),
        )
        kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)
        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)
        # WEIGHT SCALE (a zero point is also registered so checkpoint
        # loading succeeds; it is dropped in process_weights_after_loading)
        if self.qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
            ChannelQuantZPParameter = ChannelQuantScaleParameter
            weight_zero_point = ChannelQuantZPParameter(
                data=torch.empty((sum(output_partition_sizes)), dtype=torch.int8),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            PerTensorZPParameter = PerTensorScaleParameter
            weight_zero_point = PerTensorZPParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.int8),
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_zero_point", weight_zero_point)
        # INPUT SCALE (static activation quantization only)
        if self.is_static_input_scheme:
            input_scale = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
            )
            layer.register_parameter("input_scale", input_scale)
            input_zero_point = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
            )
            layer.register_parameter("input_zero_point", input_zero_point)
        self.kernel = kernel_type(
            c=scaled_mm_linear_kernel_config,
            w_q_param_name="weight",
            w_s_param_name="weight_scale",
            i_s_param_name="input_scale",
            i_zp_param_name="input_zero_point",
            azp_adj_param_name="azp_adj",
        )

    # Checkpoints are serialized in quark format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Drop unused zero points, then let the kernel repack the weights."""
        # NOTE(review): the weight zero point is always discarded —
        # presumably weights are quantized symmetrically; confirm.
        layer.register_parameter("weight_zero_point", None)
        delattr(layer, "weight_zero_point")
        if self.input_symmetric:
            layer.register_parameter("input_zero_point", None)
            delattr(layer, "input_zero_point")
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Run the scaled-mm kernel selected in ``create_weights``."""
        return self.kernel.apply_weights(layer, x, bias)