Upgrade to vLLM 0.17.0 (CoreX v4.1 overlay)
@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike():
     # TODO(bowen): When using `FusedMoEModularKernel`, this
     # can be done in a more unified way, since
-    # `FusedMoEPrepareAndFinalize` will return the expert
+    # `FusedMoEPrepareAndFinalizeModular` will return the expert
     # token count, in some cases directly from the kernel.
     # However, now there are many code paths not using
     # the modular kernel, e.g. calling `fused_experts`,
@@ -175,6 +175,7 @@ class BaseRouter(FusedMoERouter):
             topk_ids = topk_ids.to(dtype=indices_type)

         assert topk_ids.dtype == indices_type or indices_type is None
+        topk_ids = topk_ids.to(torch.int32)
         return topk_ids

     @abstractmethod
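For context, a toy snippet (illustrative only, not vLLM code) showing what the added line changes: whatever dtype `indices_type` requested, expert ids now leave the router as int32.

```python
import torch

indices_type = torch.int64  # requested index dtype, as in the hunk above
topk_ids = torch.tensor([[3, 7], [1, 2]]).to(dtype=indices_type)
assert topk_ids.dtype == indices_type or indices_type is None
topk_ids = topk_ids.to(torch.int32)  # the overlay's added cast wins
assert topk_ids.dtype == torch.int32
```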
@@ -31,7 +31,7 @@ def vllm_topk_softmax(
         token_expert_indices,
         gating_output,
         renormalize,
-        e_score_correction_bias,
+        e_score_correction_bias
     )

     return topk_weights, topk_indices
@@ -85,13 +85,14 @@ def fused_topk_bias(
     token_expert_indices = torch.empty(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
+    gating_output_float = gating_output.float()  # TODO(woosuk): Optimize this.

     if scoring_func == "softmax":
         topk_weights, topk_ids = vllm_topk_softmax(
             topk_weights,
             topk_ids,
             token_expert_indices,
-            gating_output,
+            gating_output_float,
             renormalize,
             e_score_correction_bias,
         )
@@ -186,7 +187,7 @@ class FusedTopKBiasRouter(BaseRouter):
             indices_type=indices_type,
         )

-        if self.routed_scaling_factor != 1.0:
-            topk_weights *= self.routed_scaling_factor
+        # if self.routed_scaling_factor != 1.0:
+        #     topk_weights *= self.routed_scaling_factor

         return topk_weights, topk_ids
@@ -26,8 +26,9 @@ def vllm_topk_softmax(
         topk_indices,
         token_expert_indices,
         gating_output,
         renormalize,
     )
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

     return topk_weights, topk_indices
@@ -90,13 +91,14 @@ def fused_topk(
     token_expert_indices = torch.empty(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
+    gating_output_float = gating_output.float()

     if scoring_func == "softmax":
         topk_func = dispatch_topk_softmax_func(
             use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
         )
         topk_weights, topk_ids = topk_func(
-            topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
+            topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
         )

     return topk_weights, topk_ids, token_expert_indices
@@ -105,7 +107,7 @@ def fused_topk(
             use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
         )
         topk_weights, topk_ids = topk_func(
-            topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
+            topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
         )

     return topk_weights, topk_ids, token_expert_indices
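Before the new gate file, a minimal pure-PyTorch sketch of the routing math these `fused_topk` hunks touch: softmax over the gating logits in fp32 (mirroring the hoisted `gating_output_float`), top-k selection, then the optional renormalization shown above. `topk_softmax_reference` is a hypothetical name for illustration, not a vLLM API; the real path dispatches to a fused CUDA/ROCm kernel.

```python
import torch


def topk_softmax_reference(
    gating_output: torch.Tensor,  # [num_tokens, num_experts], any float dtype
    topk: int,
    renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Compute routing probabilities in fp32, as gating_output.float() does above.
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        # Same renormalization step the @@ -26,8 +26,9 @@ hunk adds post-kernel.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)


weights, ids = topk_softmax_reference(torch.randn(4, 8), topk=2, renormalize=True)
assert weights.shape == (4, 2) and ids.dtype == torch.int32
```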
vllm/model_executor/layers/fused_moe/router/gate_linear.py (new file, +115 lines)
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.platforms import current_platform


@PluggableLayer.register("gate_linear")
class GateLinear(ReplicatedLinear):
    """MoE gate linear layer with three-tier GEMM dispatch:

    1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
    2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
    3. F.linear via ReplicatedLinear (ultimate fallback)

    The ``out_dtype`` attribute is mutable and can be set after init
    (e.g. when the required dtype depends on the expert quantization
    method, which is only known later).
    """

    # Dimensions supported by the DSV3 specialized kernel
    DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
    DSV3_SUPPORTED_HIDDEN_SIZES = [7168]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
        out_dtype: torch.dtype | None = None,
        params_dtype: torch.dtype | None = None,
        force_fp32_compute: bool = False,
        prefix: str = "",
    ):
        is_hopper_or_blackwell = current_platform.is_device_capability(
            (9, 0)
        ) or current_platform.is_device_capability_family(100)
        can_use_specialized_kernels = False

        # If fp32 compute is required and no specialized kernel is available,
        # store weights in fp32 so Tier 3 computes in fp32 natively.
        if force_fp32_compute and not can_use_specialized_kernels:
            params_dtype = torch.float32

        super().__init__(
            input_size,
            output_size,
            bias=bias,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=prefix,
        )
        self.out_dtype = out_dtype

        # DSV3 specialized kernel eligibility (SM90+, exact dims)
        self.allow_specialized_router_gemm = can_use_specialized_kernels
        self.allow_dsv3_router_gemm = (
            self.allow_specialized_router_gemm
            and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS
            and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
        )

        # cuBLAS bf16→fp32 eligibility
        self.allow_cublas_router_gemm = (
            self.allow_specialized_router_gemm
            and self.weight.dtype == torch.bfloat16
            and self.out_dtype == torch.float32
        )

    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
        """Set output dtype for the router logits after init.

        Useful when the required dtype depends on the expert quantization
        method, which is only known after the gate is constructed.
        """
        if self.out_dtype is not None:
            raise ValueError("out_dtype has already been set")
        self.out_dtype = out_dtype

        if (
            not self.allow_cublas_router_gemm
            and self.allow_specialized_router_gemm
            and out_dtype == torch.float32
        ):
            self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        import vllm._custom_ops as ops

        # Tier 1: DSV3 specialized kernel
        if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
            output = ops.dsv3_router_gemm(
                hidden_states=x,
                router_weight=self.weight,
                output_dtype=self.out_dtype,
            )
            return output, None

        # Tier 2: cuBLAS bf16→fp32
        if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
            output = ops.router_gemm_bf16_fp32(x, self.weight)
            return output, None

        # Tier 3: F.linear (ReplicatedLinear)
        if self.out_dtype is not None and x.dtype != self.weight.dtype:
            x = x.to(self.weight.dtype)
        output, output_bias = super().forward(x)
        if self.out_dtype is not None and output.dtype != self.out_dtype:
            output = output.to(self.out_dtype)
        return output, output_bias
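A hypothetical usage sketch for the new `GateLinear` layer. It assumes an initialized vLLM runtime in which this module is importable; it is not runnable standalone. Since this overlay hard-codes `can_use_specialized_kernels = False`, `forward` always falls through to the Tier 3 `F.linear` path here.

```python
import torch

# Assumes a vLLM build containing the file added above.
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear

gate = GateLinear(
    input_size=7168,          # hidden size listed in DSV3_SUPPORTED_HIDDEN_SIZES
    output_size=256,          # expert count listed in DSV3_SUPPORTED_NUM_EXPERTS
    bias=False,
    force_fp32_compute=True,  # no specialized kernel here, so weights go fp32
)
gate.set_out_dtype(torch.float32)  # deferred until the expert quant method is known

x = torch.randn(8, 7168, dtype=torch.float32)
logits, _ = gate(x)  # Tier 3 fallback in this overlay
assert logits.dtype == torch.float32
```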
@@ -92,77 +92,9 @@ def grouped_topk(
     routed_scaling_factor: float = 1.0,
     e_score_correction_bias: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    if (
-        envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
-        and current_platform.is_cuda()
-        and num_expert_group <= 32
-        and topk <= 32
-        and e_score_correction_bias is not None
-    ):
-        return fused_grouped_topk(
-            hidden_states=hidden_states,
-            gating_output=gating_output,
-            topk=topk,
-            renormalize=renormalize,
-            e_score_correction_bias=e_score_correction_bias,
-            num_expert_group=num_expert_group,
-            topk_group=topk_group,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-        )
-
-    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
-
-    if scoring_func == "softmax":
-        scores = torch.softmax(gating_output, dim=-1)
-    elif scoring_func == "sigmoid":
-        scores = gating_output.sigmoid()
-    else:
-        raise ValueError(f"Unsupported scoring function: {scoring_func}")
-
-    num_token = scores.size(0)
-    if e_score_correction_bias is not None:
-        # Store original scores before applying correction bias. We use biased
-        # scores for expert selection but original scores for routing weights
-        original_scores = scores
-        scores = scores + e_score_correction_bias.unsqueeze(0)
-        group_scores = (
-            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
-        )
-    else:
-        group_scores = (
-            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-        )  # [n, n_group]
-
-    # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_is_batch_invariant()
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
-
-    if e_score_correction_bias is not None:
-        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
-        # Use original unbiased scores for the routing weights
-        topk_weights = original_scores.gather(1, topk_ids)
-    else:
-        topk_weights, topk_ids = torch.topk(
-            tmp_scores, k=topk, dim=-1, sorted=use_sorted
-        )
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    if routed_scaling_factor != 1.0:
-        topk_weights = topk_weights * routed_scaling_factor
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    from ixformer.inference.functions import moe_grouped_topk as grouped_topk
+    topk_weights, topk_ids = grouped_topk(gating_output, topk, num_expert_group, topk_group, scoring_func, e_score_correction_bias, renormalize=renormalize)
+    return topk_weights, topk_ids


 # --8<-- [start:grouped_topk]
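For readers without a CoreX environment, a condensed pure-PyTorch sketch of the grouped top-k selection the removed native path implemented (and that `ixformer`'s `moe_grouped_topk` now provides): rank expert groups, mask out all but the best `topk_group` groups, then pick the final top-k experts within the survivors. `demo_grouped_topk` is an illustrative name, not part of either library, and it omits the sigmoid/bias-correction and renormalization variants shown above.

```python
import torch


def demo_grouped_topk(scores: torch.Tensor, topk: int, n_group: int, topk_group: int):
    n_token = scores.size(0)
    # Rank expert groups by their best expert; keep the top `topk_group` groups.
    group_scores = scores.view(n_token, n_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(n_token, n_group, scores.size(-1) // n_group)
        .reshape(n_token, -1)
    )
    # Select the final top-k experts only within the surviving groups.
    masked = scores.masked_fill(~score_mask.bool(), float("-inf"))
    return torch.topk(masked, k=topk, dim=-1)


w, ids = demo_grouped_topk(torch.rand(4, 64), topk=6, n_group=8, topk_group=3)
assert w.shape == (4, 6) and ids.shape == (4, 6)
```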
@@ -246,7 +178,6 @@ class GroupedTopk(CustomOp):
             hidden_states, gating_output, e_score_correction_bias
         )

-from ixformer.inference.functions import moe_grouped_topk as grouped_topk

 class GroupedTopKRouter(BaseRouter):
     """Router using grouped top-k routing (e.g., DeepSeekV2/V3)."""
@@ -316,8 +247,8 @@
                 topk=self.top_k,
                 renormalize=self.renormalize,
             )
-            if self.routed_scaling_factor != 1.0:
-                topk_weights *= self.routed_scaling_factor
+            # if self.routed_scaling_factor != 1.0:
+            #     topk_weights *= self.routed_scaling_factor
         else:
             topk_weights, topk_ids, token_expert_indices = fused_topk(
                 hidden_states=hidden_states,
@@ -340,14 +271,14 @@
             grouped_topk_impl = grouped_topk

         topk_weights, topk_ids = grouped_topk_impl(
-            # hidden_states=hidden_states,
+            hidden_states=hidden_states,
             gating_output=router_logits,
             topk=self.top_k,
             renormalize=self.renormalize,
             num_expert_group=self.num_expert_group,
             topk_group=self.topk_group,
             scoring_func=self.scoring_func,
-            # routed_scaling_factor=self.routed_scaling_factor,
+            routed_scaling_factor=self.routed_scaling_factor,
             e_score_correction_bias=self.e_score_correction_bias,
         )

@@ -44,7 +44,7 @@ def create_fused_moe_router(
     # grouped topk + fused topk bias parameters
     routed_scaling_factor: float = 1.0,
     e_score_correction_bias: torch.Tensor | None = None,
-    # custom routing paramaters
+    # custom routing parameters
     custom_routing_function: Callable | None = None,
     # eplb parameters
     enable_eplb: bool = False,