Support mxfp4 for GPT-OSS (#8843)
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Co-authored-by: zhuofan1123 <zhuofanl@nvidia.com>
Co-authored-by: liz-badada <jinyanc@nvidia.com>
Co-authored-by: xutizhou <xutingz@nvidia.com>
Co-authored-by: linhu-nv <linhu@nvidia.com>
python/sglang/srt/layers/quantization/mxfp4.py (new file, 443 lines)
@@ -0,0 +1,443 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import importlib
import logging
from typing import TYPE_CHECKING, Callable, List, Optional

import torch
from torch.nn.parameter import Parameter

# from vllm.model_executor.layers.fused_moe import (
#     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
#     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import (
    direct_register_custom_op,
    is_cuda,
    is_flashinfer_available,
    is_hip,
    next_power_of_2,
    round_up,
    set_weight_attrs,
)

has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

if is_flashinfer_available():
    # from flashinfer.fused_moe import cutlass_fused_moe
    from flashinfer import (
        mxfp8_quantize,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        trtllm_fp4_block_scale_moe,
    )

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.layers.moe.topk import TopKOutput

OCP_MX_BLOCK_SIZE = 32


def _swizzle_mxfp4(quant_tensor, scale, num_warps):
    """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel"""
    import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
    from triton_kernels.numerics import InFlexData
    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
    from triton_kernels.tensor_details import layout

    value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
        mx_axis=1
    )
    scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
        mx_axis=1, num_warps=num_warps
    )
    if is_cuda() and torch.cuda.get_device_capability()[0] == 10:
        constraints = {
            "is_persistent": True,
            "epilogue_subtile": 1,
        }
        opt_flags.update_opt_flags_constraints(constraints)
    # transpose the tensor so that the quantization axis is on dim1
    quant_tensor = quant_tensor.transpose(-2, -1)
    scale = scale.transpose(-2, -1)
    quant_tensor = convert_layout(
        wrap_torch_tensor(quant_tensor, dtype=FP4), value_layout, **value_layout_opts
    )
    scale = convert_layout(wrap_torch_tensor(scale), scale_layout, **scale_layout_opts)
    return quant_tensor, InFlexData(), scale
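# NOTE (illustrative, not a contract of the kernel API): when called from
# `process_weights_after_loading` below, `quant_tensor` is an expert weight of
# shape [num_experts, rows, cols // 2] in uint8 (two e2m1 values packed per
# byte) and `scale` holds one scale byte per 32-value block, i.e.
# [num_experts, rows, cols // 32]. The function returns the same data
# re-wrapped as triton_kernels tensors in the layout the OAI matmul kernel
# expects, plus an InFlexData placeholder for the flex-point context.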


def _dequant_mxfp4(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor:
    try:
        from quark.torch.kernel import mx
    except ImportError as err:
        raise ImportError(
            "The package `amd-quark` is required to use "
            "MX-FP4 models. Please install it with `pip install "
            "amd-quark`."
        ) from err

    return mx.dq_mxfp4(x, scale, float_dtype)


def _dequant_mxfp4_fake(
    x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
) -> torch.Tensor:
    return torch.empty(
        (*x.shape[:-1], x.shape[-1] * 2), dtype=float_dtype, device=x.device
    )


def _quant_dequant_mxfp4(
    x: torch.Tensor, scale_calculation_mode: str = "even"
) -> torch.Tensor:
    try:
        from quark.torch.kernel import mx
    except ImportError as err:
        raise ImportError(
            "The package `amd-quark` is required to use "
            "MX-FP4 models. Please install it with `pip install "
            "amd-quark`."
        ) from err

    return mx.qdq_mxfp4(x, scale_calculation_mode)


def _quant_dequant_mxfp4_fake(
    x: torch.Tensor, scale_calculation_mode: str = "even"
) -> torch.Tensor:
    return torch.empty_like(x)
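# NOTE (illustrative): the fake implementations above only encode the output
# shape and dtype for fake-tensor tracing; the real work is done by the quark
# kernels. The packing follows the OCP MX convention assumed throughout this
# file: each uint8 byte holds two e2m1 (fp4) values, and every group of
# OCP_MX_BLOCK_SIZE = 32 values shares one scale byte, so dequantizing a
# packed tensor of shape [n, k // 2] yields a [n, k] tensor in `float_dtype`.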


try:
    direct_register_custom_op(
        op_name="dequant_mxfp4",
        op_func=_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_dequant_mxfp4_fake,
    )
    dequant_mxfp4 = torch.ops.sglang.dequant_mxfp4
except AttributeError as error:
    raise error

try:
    direct_register_custom_op(
        op_name="quant_dequant_mxfp4",
        op_func=_quant_dequant_mxfp4,
        mutates_args=[],
        fake_impl=_quant_dequant_mxfp4_fake,
    )
    quant_dequant_mxfp4 = torch.ops.sglang.quant_dequant_mxfp4
except AttributeError as error:
    raise error
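# A minimal usage sketch of the registered ops (assumes `amd-quark` is
# installed and registration above succeeded; the shapes are hypothetical):
#
#   packed = torch.randint(0, 256, (8, 64), dtype=torch.uint8, device="cuda")
#   scale = torch.zeros(8, 128 // OCP_MX_BLOCK_SIZE, dtype=torch.uint8, device="cuda")
#   dense = dequant_mxfp4(packed, scale, torch.bfloat16)   # -> shape (8, 128)
#   roundtrip = quant_dequant_mxfp4(dense)                 # fake-quantize, same shape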


class Mxfp4Config(QuantizationConfig):

    def __init__(self, ignored_layers: Optional[list[str]] = None):
        super().__init__()
        self.ignored_layers = ignored_layers

    @classmethod
    def from_config(cls, config):
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_name(cls) -> str:
        return "mxfp4"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:

        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

        if isinstance(layer, LinearBase):
            if self.ignored_layers and is_layer_skipped(
                prefix=prefix,
                ignored_layers=self.ignored_layers,
                fused_mapping=self.packed_modules_mapping,
            ):
                return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
        else:
            raise NotImplementedError("Mxfp4 attention layer is not implemented")
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Mxfp4MoEMethod(FusedMoEMethodBase):

    def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
        super().__init__()
        self.topk_indices_dtype = None
        self.use_triton_kernels = use_triton_kernels
        self.with_bias = with_bias
        self.triton_kernel_moe_forward = None
        self.triton_kernel_moe_with_bias_forward = None
        if torch.cuda.is_available() and has_triton_kernels:
            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                triton_kernel_moe_forward as _tk_forward,
            )
            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
            )

            self.triton_kernel_moe_forward = _tk_forward
            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8

        intermediate_size *= 2
        mxfp4_block = 32

        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts, 2 * intermediate_size, hidden_size // 2, dtype=weight_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size,
                hidden_size // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w13_weight_bias = torch.nn.Parameter(
            torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_bias", w13_weight_bias)
        set_weight_attrs(w13_weight_bias, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.zeros(
                num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        w2_weight_bias = torch.nn.Parameter(
            torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_bias", w2_weight_bias)
        set_weight_attrs(w2_weight_bias, extra_weight_attrs)
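    # Shape sketch (hypothetical sizes, for orientation only): with
    # num_experts=32, hidden_size=2880 and an `intermediate_size` argument of
    # 1440 (doubled above to 2880), create_weights registers
    #   w13_weight        [32, 5760, 1440]  uint8  (two e2m1 values per byte)
    #   w13_weight_scale  [32, 5760, 90]    uint8  (one scale per 32 values)
    #   w2_weight         [32, 2880, 1440]  uint8
    #   w2_weight_scale   [32, 2880, 90]    uint8
    # plus bf16 per-channel biases w13_weight_bias [32, 5760] and
    # w2_weight_bias [32, 2880].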

    def process_weights_after_loading(self, layer):

        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
        w2_weight_bias = layer.w2_weight_bias.to(torch.float32)

        layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
        layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)

        num_warps = 8

        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            layer.w13_weight, layer.w13_weight_scale, num_warps
        )
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            layer.w2_weight, layer.w2_weight_scale, num_warps
        )

        self.w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
        )
        self.w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
        )

        self.w13_weight_triton_tensor = w13_weight
        self.w2_weight_triton_tensor = w2_weight

        # need to delete the original weights to save memory on single GPU
        del layer.w13_weight
        del layer.w2_weight
        layer.w13_weight = None
        layer.w2_weight = None
        torch.cuda.empty_cache()
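        # Note: the swizzled triton_kernels tensors and their PrecisionConfigs
        # are kept on the method object (`self`) rather than on `layer`, since
        # they are wrapper objects rather than plain nn.Parameters; `apply`
        # below therefore asserts that layer.w13_weight / layer.w2_weight have
        # been cleared.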

    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
        # Number of tokens in the input tensor.
        num_tokens = x.shape[0]
        # Factor to account for the imbalance of the experts. The factor equals
        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert:
        # - 1.0 means a perfect expert distribution.
        # - > 1.0 means some experts receive more tokens than the perfect
        #   distribution.
        # - < 1.0 does not make sense.
        imbalance_factor = 1.3
        # Calculate the number of tokens per expert assuming perfect
        # distribution.
        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
        # Apply the imbalance factor.
        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
        # And pad the number to the next power of 2.
        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
        # Cap to 8-64 tokens per CTA tile, as that is the range supported by
        # the kernel.
        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

        return tile_tokens_dim
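    # Worked example (illustrative numbers only): with num_tokens=512,
    # top_k=4 and num_experts=128, perfectly balanced routing gives
    # 512 * 4 // 128 = 16 tokens per expert; applying the 1.3 imbalance factor
    # gives int(16 * 1.3) = 20, the next power of 2 is 32, and 32 already lies
    # within the supported [8, 64] range, so tile_tokens_dim = 32.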

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
        activation_alpha: Optional[float] = None,
        swiglu_limit: Optional[float] = None,
    ) -> torch.Tensor:
        # Avoid an import error when triton_kernels is not installed.
        # from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
        #     triton_kernel_moe_forward)

        """
        if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
                or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
            assert not self.moe.use_ep, (
                "EP is not supported for flashinfer mxfp4 moe backend yet.")
            if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            else:
                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
                x_quant,
                x_scale,
                layer.w13_weight,  # uint8 (e2m1 x 2)
                layer.w13_weight_scale,  # uint8 (e4m3 x 2)
                layer.w13_weight_bias,  # fp32 per expert per channel
                layer.gemm1_alpha,  # fp32 per expert
                layer.gemm1_beta,  # fp32 per expert
                layer.gemm1_clamp_limit,  # fp32 per expert
                layer.w2_weight,  # uint8 (e2m1 x 2)
                layer.w2_weight_scale,  # ue8m0
                layer.w2_weight_bias,  # fp32 per expert per channel
                None,  # output1_scale_scalar
                None,  # output1_scale_gate_scalar
                None,  # output2_scale_scalar
                self.num_experts,
                top_k,
                None,  # n_group
                None,  # topk_group
                self.intermediate_size,  # padded to multiple of 256
                0,  # local_expert_offset
                self.num_experts,  # local num experts
                None,
                self._get_tile_tokens_dim(x, top_k),
                1,  # routing_method_type, renormalize
                True,  # do finalize
            )[0]
            return trtllm_gen_output
        """

        if self.use_triton_kernels:
            if self.with_bias:
                # TODO: why don't we put the weights on the layer?
                assert layer.w13_weight is None
                assert layer.w2_weight is None
                return self.triton_kernel_moe_with_bias_forward(
                    hidden_states=x,
                    w1=self.w13_weight_triton_tensor,
                    w1_pcg=self.w13_precision_config,
                    w2=self.w2_weight_triton_tensor,
                    w2_pcg=self.w2_precision_config,
                    b1=layer.w13_weight_bias,
                    b2=layer.w2_weight_bias,
                    topk_output=topk_output,
                    activation=activation,
                    activation_alpha=activation_alpha,
                    swiglu_limit=swiglu_limit,
                )
            else:
                return self.triton_kernel_moe_forward(
                    hidden_states=x,
                    w1=layer.w13_weight,
                    w2=layer.w2_weight,
                    topk_output=topk_output,
                )
        else:
            raise NotImplementedError()