[gpt-oss] Add gpt-oss mxfp4 support

This commit is contained in:
2025-08-25 15:31:09 +08:00
parent db7f48eeac
commit 7a35b2f32d
32 changed files with 4835 additions and 1190 deletions

View File

@@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return s
def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
origin_shape = s.shape
_, scale_perm_single = get_scale_perms()
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
return s.reshape(*origin_shape).contiguous()
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
@@ -410,6 +417,7 @@ def apply_gptq_marlin_linear(
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
bias,
weight_scale,
None,
weight_zp,
@@ -425,9 +433,6 @@ def apply_gptq_marlin_linear(
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
@@ -456,6 +461,7 @@ def apply_awq_marlin_linear(
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
bias,
weight_scale,
None,
weight_zp,
@@ -470,7 +476,4 @@ def apply_awq_marlin_linear(
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -8,8 +8,8 @@ import torch
import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
should_use_atomic_add_reduce)
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias,
marlin_permute_scales, should_use_atomic_add_reduce)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
@@ -22,7 +22,7 @@ def is_fp4_marlin_supported():
return current_platform.has_device_capability(80)
def fp4_marlin_process_scales(marlin_scales):
def nvfp4_marlin_process_scales(marlin_scales):
if not (marlin_scales >= 0).all():
logger.warning_once(
"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
@@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales):
return marlin_scales
def fp4_marlin_process_global_scale(global_scale):
def mxfp4_marlin_process_scales(marlin_scales):
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1)
# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1)
marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
return marlin_scales
def nvfp4_marlin_process_global_scale(global_scale):
assert global_scale.dtype in [torch.half, torch.bfloat16]
fp4_exponent = 2
if global_scale.dtype == torch.half:
@@ -73,7 +86,7 @@ def apply_fp4_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
weight_scale_2: Optional[torch.Tensor],
workspace: torch.Tensor,
size_n: int,
size_k: int,
@@ -94,6 +107,7 @@ def apply_fp4_marlin_linear(
output = ops.gptq_marlin_gemm(a=reshaped_x,
c=None,
b_q_weight=weight,
b_bias=bias,
b_scales=weight_scale,
global_scale=weight_scale_2,
b_zeros=None,
@@ -107,9 +121,6 @@ def apply_fp4_marlin_linear(
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
@@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
is_nvfp4 = hasattr(layer, "weight_scale_2")
group_size = 16 if is_nvfp4 else 32
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
param_dtype = layer.params_dtype
@@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES
# Permute scales
weight_scale = layer.weight_scale.T.to(param_dtype)
weight_scale = layer.weight_scale.T.contiguous()
if not is_nvfp4:
weight_scale = weight_scale.view(torch.float8_e8m0fnu)
weight_scale = weight_scale.to(param_dtype)
weight_scale = marlin_permute_scales(s=weight_scale,
size_k=part_size_k,
size_n=part_size_n,
group_size=16)
weight_scale = fp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
group_size=group_size)
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
requires_grad=False)
if is_nvfp4:
weight_scale = nvfp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale,
requires_grad=False)
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
requires_grad=False)
else:
weight_scale = mxfp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale,
requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None:
assert layer.bias.shape == (part_size_n, )
bias = marlin_permute_bias(layer.bias)
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
return
@@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
group_size = 16 if is_nvfp4 else 32
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
@@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES
# Permute scales
for name in ["w13", "w2"]:
scales = getattr(layer, name + "_weight_scale").to(param_dtype)
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
scales = getattr(layer, name + "_weight_scale")
if not is_nvfp4:
scales = scales.view(torch.float8_e8m0fnu)
scales = scales.to(param_dtype)
if is_nvfp4:
global_scale = getattr(layer,
name + "_weight_scale_2").to(param_dtype)
tensor_list = []
if "w13" in name:
@@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
size_n, size_k = k, n
for i in range(e):
marlin_scales = marlin_permute_scales(s=scales[i].T,
scale = scales[i].T
marlin_scales = marlin_permute_scales(s=scale,
size_k=size_k,
size_n=size_n,
group_size=16)
marlin_scales = fp4_marlin_process_scales(marlin_scales)
group_size=group_size)
if is_nvfp4:
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
else:
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales)
global_scale = fp4_marlin_process_global_scale(global_scale)
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
setattr(layer, name + "_weight_scale_2", global_scale)
if is_nvfp4:
global_scale = nvfp4_marlin_process_global_scale(global_scale)
global_scale = torch.nn.Parameter(global_scale,
requires_grad=False)
setattr(layer, name + "_weight_scale_2", global_scale)
# BIAS
# Permute bias
for name in ["w13_bias", "w2_bias"]:
if not hasattr(layer, name):
continue
bias = getattr(layer, name).to(param_dtype)
tensor_list = []
for i in range(e):
expert_bias = bias[i]
tensor_list.append(marlin_permute_bias(expert_bias))
bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
bias = torch.nn.Parameter(bias, requires_grad=False)
setattr(layer, name, bias)
def rand_marlin_weight_fp4_like(weight, group_size):
def rand_marlin_weight_nvfp4_like(weight, group_size):
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
@@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size):
size_k=size_k,
size_n=size_n,
group_size=group_size)
marlin_scales = fp4_marlin_process_scales(marlin_scales)
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
global_scale = fp4_marlin_process_global_scale(global_scale)
global_scale = nvfp4_marlin_process_global_scale(global_scale)
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
def rand_marlin_weight_mxfp4_like(weight, group_size):
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
scales = torch.randint(100,
125, (size_n, size_k // group_size),
dtype=torch.uint8,
device=weight.device)
scales = scales.view(torch.float8_e8m0fnu)
fp4_weight = torch.randint(0,
256, (size_n, size_k // 2),
dtype=torch.uint8,
device=weight.device)
fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
((fp4_weight & 0b01110000) >> 2))
fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
fp4_weight2 = fp4_weight << 4
fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
((fp4_weight2 & 0b01110000) >> 2))
fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
weight_ref = torch.cat(
[fp4_weight_part_2.unsqueeze(2),
fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
weight_ref = weight_ref * \
scales.repeat_interleave(group_size, 1).to(weight.dtype)
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
perm=torch.empty(0, dtype=torch.int, device=device),
size_k=size_k,
size_n=size_n,
num_bits=4,
)
marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size)
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)

View File

@@ -1,45 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
OCP_MX_BLOCK_SIZE = 32
def per_token_group_quant_mxfp4(x: torch.Tensor,
block_k: int,
scale_calculation_mode: str = "even"
) -> tuple[torch.Tensor, torch.Tensor]:
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
""" weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
"""
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
from triton_kernels.numerics import InFlexData
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout
if (current_platform.is_cuda()
and current_platform.is_device_capability(90)
and not is_torch_equal_or_newer("2.8.1")):
logger.warning_once(
"Mxfp4 on hopper is running on torch < 2.8.1, "
"this cause swizling to be disabled, which may "
"cause performance degradation. Please upgrade to torch nightly")
value_layout, value_layout_opts = StridedLayout, dict()
scale_layout, scale_layout_opts = StridedLayout, dict()
else:
value_layout, value_layout_opts = \
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
scale_layout, scale_layout_opts = (
layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps))
if current_platform.is_cuda() and \
current_platform.is_device_capability(100):
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
# transpose the tensor so that the quantization axis is on dim1
quant_tensor = quant_tensor.transpose(-2, -1)
scale = scale.transpose(-2, -1)
quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
value_layout, **value_layout_opts)
scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
**scale_layout_opts)
return quant_tensor, InFlexData(), scale
def _can_support_mxfp4(use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
scoring_func: str = "softmax",
activation: str = "swiglu_oai",
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None):
return not (use_grouped_topk or topk_group or num_expert_group
or expert_map or custom_routing_function
or e_score_correction_bias or apply_router_weight_on_input
or scoring_func != "softmax" or activation != "swiglu_oai"
or expert_load_view or logical_to_physical_map
or logical_replica_count)
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale)
from quark.torch.quantization.utils import (even_round,
reshape_to_blocks)
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
axis = -1
block_x = reshape_to_blocks(x, block_k, axis)
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
amax = amax.squeeze(-1)
return mx.dq_mxfp4(x, scale, float_dtype)
# TODO: there are other rounding strategies supported in quark and in the
# config.json that we do not check for here!
if scale_calculation_mode != "even":
raise NotImplementedError(
f"Scale calculation mode {scale_calculation_mode} is not yet "
"supported in MX-FP4 quantization")
scale = even_round(amax, "fp4")
# Apply dequantize(quantize(x)).
x = fake_quantize_fp4_fp6_per_group_with_scale(
x,
scale.to(x.device),
axis=axis,
group_size=block_k,
quant_dtype="fp4",
def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
return torch.empty((*x.shape[:-1], x.shape[-1] * 2),
dtype=float_dtype,
device=x.device)
def _quant_dequant_mxfp4(x: torch.Tensor,
scale_calculation_mode: str = "even") -> torch.Tensor:
try:
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
return mx.qdq_mxfp4(x, scale_calculation_mode)
def _quant_dequant_mxfp4_fake(x: torch.Tensor,
scale_calculation_mode: str = "even"
) -> torch.Tensor:
return torch.empty_like(x)
try:
direct_register_custom_op(
op_name="dequant_mxfp4",
op_func=_dequant_mxfp4,
mutates_args=[],
fake_impl=_dequant_mxfp4_fake,
)
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
except AttributeError as error:
raise error
return x, scale
try:
direct_register_custom_op(
op_name="quant_dequant_mxfp4",
op_func=_quant_dequant_mxfp4,
mutates_args=[],
fake_impl=_quant_dequant_mxfp4_fake,
)
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
except AttributeError as error:
raise error

View File

@@ -3,22 +3,41 @@
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from types import MappingProxyType
from typing import Optional
from typing import ClassVar, NamedTuple, Optional
import numpy
import torch
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int
col: int
class GroupShape(_GroupShape):
"""
This class describes the quantization group shape.
It includes static members for common shapes (per-tensor, per-token).
"""
# Aliases for common quantization group shapes
PER_TENSOR: ClassVar['GroupShape']
PER_TOKEN: ClassVar['GroupShape']
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
int]):
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# -1 means full extent
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
@@ -58,7 +77,7 @@ def group_broadcast(t, shape):
# (i.e. per-token-per-group)
def scaled_quantize(
x: torch.Tensor,
group_shape: tuple[int, int],
group_shape: GroupShape,
quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
group_shape = _normalize_quant_group_shape(x, group_shape)
@@ -99,7 +118,7 @@ def scaled_quantize(
def scaled_dequantize(
x_q: torch.Tensor,
x_s: torch.Tensor,
group_shape: Optional[tuple[int, int]] = None,
group_shape: Optional[GroupShape] = None,
out_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
if group_shape is not None:
@@ -332,6 +351,10 @@ def quantize_weights(w: torch.Tensor,
)
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def gptq_quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
@@ -571,3 +594,56 @@ def awq_pack(
q_w = q_w.reshape((-1, size_n)).contiguous()
return pack_cols(q_w, num_bits, size_k, size_n)
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
"""
Pad and block-interleave the FP4 block-scales so that they match the data
layout expected by the CUTLASS / FlashInfer kernels.
Parameters
----------
scale: torch.Tensor
Returns
-------
torch.Tensor
The swizzled tensor with the same logical shape as *scale*.
"""
assert scale.dtype == torch.float8_e4m3fn, (
"swizzle_blockscale expects the input tensor to be in "
"torch.float8_e4m3fn format.")
scale_ndim = scale.ndim
if scale_ndim == 2:
scale = scale.unsqueeze(0) # (1, M, K)
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
B, M, K = scale.shape
def _round_up(x: int, m: int) -> int:
return (x + m - 1) // m * m
M_padded = _round_up(M, 128)
K_padded = _round_up(K, 4)
padded = torch.zeros((B, M_padded, K_padded),
dtype=scale.dtype,
device=scale.device)
padded[:B, :M, :K] = scale
# Reshape / permute to the layout required by the kernel.
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
if scale_ndim == 2:
return swizzled.reshape(M, K)
return swizzled.reshape(B, M, K)
def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)