Upgrade to vLLM 0.17.0 (corex v4.1 overlay)

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -6,6 +6,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
@@ -77,6 +78,8 @@ class Mxfp4Backend(Enum):
# Triton Backend
TRITON = 6
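# CK (Composable Kernel / aiter) backend, ROCm only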
CK = 7
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
"""
@@ -167,9 +170,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
elif current_platform.is_xpu():
logger.info_once("Using xpu backend on XPU")
return Mxfp4Backend.MARLIN
elif current_platform.is_rocm() and has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
elif current_platform.is_rocm():
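# On ROCm, prefer the aiter CK MXFP4 MoE kernel on gfx950 when aiter is
# enabled; otherwise fall back to the Triton kernels if available.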
from vllm.platforms.rocm import on_gfx950
if rocm_aiter_ops.is_enabled() and on_gfx950():
logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
return Mxfp4Backend.CK
elif has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
return Mxfp4Backend.NONE
@@ -257,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_mk: mk.FusedMoEModularKernel | None = None
self.moe_kernel: mk.FusedMoEKernel | None = None
def create_weights(
self,
@@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
self.intermediate_pad = (
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
)
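# Record how much padding was added; the CK path forwards hidden_pad /
# intermediate_pad to the aiter fused_moe call in apply().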
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
@@ -427,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
MarlinExperts(
self.moe,
@@ -776,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
FlashInferExperts(
moe_config=self.moe,
@@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
),
shared_experts=None,
)
elif self.mxfp4_backend == Mxfp4Backend.CK:
if layer.w13_bias is not None:
layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
if layer.w2_bias is not None:
layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)
e, n, k = layer.w13_weight.shape
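# Reorder the fused gate_up rows: viewing as (e, n // 2, 2, k) and swapping
# the middle axes puts even-indexed rows first and odd-indexed rows second
# (presumably the row layout the aiter CK kernel expects).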
layer.w13_weight.view(torch.uint8).copy_(
layer.w13_weight.data.view(torch.uint8)
.view(e, n // 2, 2, k)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, k)
)
layer.w13_weight_scale.data = (
layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, -1)
)
layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)
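# Pre-shuffle weights and scales once at load time into the layout expected
# by aiter's a16w4 (MXFP4) CK kernels.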
layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w13_weight, 16, True
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
self.num_experts,
True,
)
layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
layer.w2_weight, 16, False
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
self.num_experts,
False,
)
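# Apply the same even/odd row reordering to the fused gate_up bias so it
# stays aligned with the reshuffled w13 rows.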
layer.w13_bias.data = (
layer.w13_bias.data.view(-1, n // 2, 2)
.permute(0, 2, 1)
.contiguous()
.view(-1, n)
)
layer.w13_weight_scale = torch.nn.Parameter(
shuffled_w13_scale, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
shuffled_w2_scale, requires_grad=False
)
# replace_parameter(layer, "w13_bias", w13_bias)
# replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
# replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
# replace_parameter(layer, "w13_weight", w13_weight)
# replace_parameter(layer, "w2_weight", w2_weight)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
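As an aside (not part of the diff), a minimal sketch of what the view/permute/contiguous reshuffle in the CK branch above does to the fused gate_up row order, using a toy tensor:

import torch

e, n, k = 1, 4, 3  # toy shapes: experts, fused gate_up rows, packed columns
w = torch.arange(e * n * k).view(e, n, k)
reordered = w.view(e, n // 2, 2, k).permute(0, 2, 1, 3).contiguous().view(e, n, k)
# row order in w:         [row0, row1, row2, row3]
# row order in reordered: [row0, row2, row1, row3]  (even rows first, then odd)

The same index permutation is applied to w13_weight_scale and w13_bias above so they stay aligned with the reshuffled weight rows.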
@@ -792,18 +865,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
# (stored in self.fused_experts) to determine if the MoE has a
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
is_batched_moe = self.moe.use_deepep_ll_kernels
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
@@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight = w13_weight
self.w2_weight = w2_weight
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
else:
raise ValueError(
f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
@@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_BF16,
Mxfp4Backend.SM90_FI_MXFP4_BF16,
Mxfp4Backend.CK,
]:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
@@ -882,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
@@ -929,10 +1001,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
@property
def is_monolithic(self) -> bool:
if self.moe.is_lora_enabled:
return False
return (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.TRITON
or self.mxfp4_backend == Mxfp4Backend.CK
)
def apply(
@@ -968,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
or self.mxfp4_backend == Mxfp4Backend.MARLIN
)
assert self.moe_mk is not None
return self.moe_mk(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -1054,6 +1129,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
return trtllm_gen_output
elif self.mxfp4_backend == Mxfp4Backend.CK:
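# CK path: aiter fused top-k routing, then aiter's fused MoE kernel run
# directly on the pre-shuffled MXFP4 (per-1x32 scaled) weights with swiglu
# activation and the padding recorded in create_weights.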
topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
x, router_logits, layer.top_k, True
)
output = rocm_aiter_ops.fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
doweight_stage1=False,
hidden_pad=self.hidden_pad // 128 * 128,
intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
bias1=layer.w13_bias,
bias2=layer.w2_bias,
)
return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
@@ -1162,7 +1258,7 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
topk_weights=routing_weights,
topk_ids=selected_experts,
n_experts_per_token=layer.top_k,
activation=layer.activation,
activation=layer.activation.value,
num_experts=layer.local_num_experts,
is_mxfp4=True,
)