update

vllm/model_executor/layers/fused_moe/router/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/model_executor/layers/fused_moe/router/base_router.py (new file, 249 lines)
@@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Callable

import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)
from vllm.platforms import current_platform

if current_platform.is_cuda_alike():

    @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
    def eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> torch.Tensor:
        """
        Map the logical expert ids to physical expert ids
        and record the expert load metrics.

        This will select a pseudo-random replica for each logical expert.
        Only used for EPLB.

        Args:
            topk_ids: The logical expert ids.
            expert_load_view: The expert load view.
            logical_to_physical_map: The logical to physical map.
            logical_replica_count: The logical replica count.

        Returns:
            The physical expert ids.
        """

        # 1. Convert the logical expert ids to physical expert ids
        # Directly select a random replica for each logical expert

        # In case `indices_type` is not `torch.long` or `torch.int`,
        # e.g. `torch.uint32` as required by dispatch/combine kernels
        topk_ids_long = topk_ids.long()
        # Use (token position) modulo (replica count)
        # to deterministically choose a replica
        replica_count = logical_replica_count[topk_ids_long]
        # Flatten-position based index, reshaped back to `topk_ids` shape
        pos_indices = torch.arange(
            topk_ids.numel(), device=topk_ids.device, dtype=torch.long
        ).reshape_as(topk_ids)
        # Compute pseudo-random indices by modulo
        replica_indices = (pos_indices % replica_count).unsqueeze(-1)
        physical_ids = (
            logical_to_physical_map[topk_ids_long]
            .gather(-1, replica_indices)
            .squeeze(-1)
        )

        topk_ids = physical_ids

        # 2. Record expert load metrics.

        # TODO(bowen): When using `FusedMoEModularKernel`, this
        # can be done in a more unified way, since
        # `FusedMoEPrepareAndFinalize` will return the expert
        # token count, in some cases directly from the kernel.
        # However, there are currently many code paths not using
        # the modular kernel, e.g. calling `fused_experts`,
        # so we decided to keep the logic here.
        #
        # If a later refactor moves all the MoE kernel calls
        # to the modular kernel, we can move this logic there
        # to achieve better efficiency.

        # `expert_load_view`: (num_physical_experts,)

        # `torch.bincount` is not compilable, so use `scatter_add_` instead.
        topk_ids_flatten = topk_ids.flatten()
        expert_load_view.scatter_add_(
            dim=0,
            index=topk_ids_flatten.long(),
            src=torch.ones_like(topk_ids_flatten).to(expert_load_view),
        )
        return topk_ids
else:

    def eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> torch.Tensor:
        # CPU fallback: no EPLB, so just return the ids as-is
        return topk_ids

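To make the replica-selection arithmetic above concrete, here is a small self-contained sketch with toy tensors (all values are illustrative, not taken from a real EPLB state):

import torch

# 4 logical experts; experts 0 and 2 each have two physical replicas,
# experts 1 and 3 have one (-1 pads the unused slot).
logical_to_physical_map = torch.tensor([[0, 4], [1, -1], [2, 5], [3, -1]])
logical_replica_count = torch.tensor([2, 1, 2, 1])

topk_ids = torch.tensor([[0, 2], [1, 2]])  # logical ids, shape [tokens, top_k]

# Flattened token position modulo replica count picks the replica.
pos = torch.arange(topk_ids.numel()).reshape_as(topk_ids)
replica_idx = (pos % logical_replica_count[topk_ids]).unsqueeze(-1)
physical_ids = (
    logical_to_physical_map[topk_ids].gather(-1, replica_idx).squeeze(-1)
)
print(physical_ids)  # tensor([[0, 5], [1, 5]])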
class BaseRouter(FusedMoERouter):
    """
    Base router class that provides common functionality for all router
    implementations.

    This class implements the template method pattern: select_experts() handles
    the common pre- and post-processing, delegating the actual routing logic
    to the abstract _compute_routing() method.
    """

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        enable_eplb: bool = False,
        # TODO(bnell): Once the MK is constructed at layer init time, we
        # can make this a plain value instead of a callback.
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        """
        Note: the indices dtype might not be available at router construction
        time, so we need to supply a callback to get it at runtime. This is
        because the indices type is supplied by modular kernels, which are
        created after MoE layer/router construction.
        """
        super().__init__()
        self.top_k = top_k
        self.global_num_experts = global_num_experts
        self.eplb_state = eplb_state
        self.enable_eplb = enable_eplb
        self.indices_type_getter = indices_type_getter
        self.capture_fn: Callable[[torch.Tensor], None] | None = None

    def set_capture_fn(
        self, capture_fn: Callable[[torch.Tensor], None] | None
    ) -> None:
        """Set a capture callback for logical routed expert IDs."""
        self.capture_fn = capture_fn

    def _validate_eplb_state(self) -> None:
        """Validate that EPLB state is properly initialized if EPLB is enabled."""
        if self.enable_eplb:
            if self.eplb_state.expert_load_view is None:
                raise ValueError("enable_eplb=True requires expert_load_view != None")
            if self.eplb_state.logical_to_physical_map is None:
                raise ValueError(
                    "enable_eplb=True requires logical_to_physical_map != None"
                )
            if self.eplb_state.logical_replica_count is None:
                raise ValueError(
                    "enable_eplb=True requires logical_replica_count != None"
                )

    def _get_indices_type(self) -> torch.dtype | None:
        """Get the desired indices dtype from the getter function."""
        return (
            self.indices_type_getter() if self.indices_type_getter is not None else None
        )

    def _apply_eplb_mapping(self, topk_ids: torch.Tensor) -> torch.Tensor:
        """Apply EPLB mapping to convert logical expert IDs to physical expert IDs."""
        if self.enable_eplb:
            assert self.eplb_state.expert_load_view is not None
            assert self.eplb_state.logical_to_physical_map is not None
            assert self.eplb_state.logical_replica_count is not None
            return eplb_map_to_physical_and_record(
                topk_ids=topk_ids,
                expert_load_view=self.eplb_state.expert_load_view,
                logical_to_physical_map=self.eplb_state.logical_to_physical_map,
                logical_replica_count=self.eplb_state.logical_replica_count,
            )
        return topk_ids

    def _convert_indices_dtype(
        self, topk_ids: torch.Tensor, indices_type: torch.dtype | None
    ) -> torch.Tensor:
        """Convert topk_ids to the desired dtype if needed."""
        if (indices_type is not None) and topk_ids.dtype != indices_type:
            topk_ids = topk_ids.to(dtype=indices_type)

        assert indices_type is None or topk_ids.dtype == indices_type
        return topk_ids

    @abstractmethod
    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the actual routing logic.

        This method must be implemented by subclasses to provide the specific
        routing algorithm (e.g., grouped_topk, fused_topk, custom routing, etc.).

        Args:
            hidden_states: Input hidden states
            router_logits: Router logits for expert selection
            indices_type: Desired dtype for expert indices (may be None)

        Returns:
            tuple of (topk_weights, topk_ids)
        """
        raise NotImplementedError

    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route the input hidden states to the top-k experts based on the
        router logits.

        This method implements the template method pattern:
        1. Validates the EPLB state
        2. Gets the indices type
        3. Calls _compute_routing() to get topk_weights and topk_ids
        4. Applies the EPLB mapping if enabled
        5. Converts the indices dtype if needed

        Returns:
            (topk_weights, topk_ids)
            (tuple[torch.Tensor, torch.Tensor]):
                The weights and expert ids computation result.

            **Compatibility**: When EPLB is not enabled, the returned ids are
            equivalent to global logical ids, so they should be compatible with
            plain MoE implementations without redundant experts.
        """
        # Step 1: Validate EPLB state
        self._validate_eplb_state()

        # Step 2: Get indices type.
        indices_type = self._get_indices_type()

        # Step 3: Compute routing (delegated to subclass)
        topk_weights, topk_ids = self._compute_routing(
            hidden_states, router_logits, indices_type
        )

        # Capture logical ids before EPLB mapping.
        if self.capture_fn is not None:
            self.capture_fn(topk_ids)

        # Step 4: Apply EPLB mapping
        topk_ids = self._apply_eplb_mapping(topk_ids)

        # Step 5: Convert indices dtype
        topk_ids = self._convert_indices_dtype(topk_ids, indices_type)

        return topk_weights, topk_ids
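As a usage illustration of the template method pattern, a minimal subclass only has to supply _compute_routing() and routing_method_type; select_experts() then handles EPLB validation, mapping, and dtype conversion. NaiveSoftmaxRouter below is a hypothetical sketch, not part of vLLM:

import torch

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter


class NaiveSoftmaxRouter(BaseRouter):  # hypothetical example class
    @property
    def routing_method_type(self) -> RoutingMethodType:
        return RoutingMethodType.Custom

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Plain softmax + top-k; the base class handles the rest.
        scores = torch.softmax(router_logits, dim=-1)
        topk_weights, topk_ids = torch.topk(scores, k=self.top_k, dim=-1)
        return topk_weights, topk_ids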
vllm/model_executor/layers/fused_moe/router/custom_routing_router.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch

from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter


class CustomRoutingRouter(BaseRouter):
    """Router using a custom user-provided routing function."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        custom_routing_function: Callable,
        renormalize: bool = True,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        self.custom_routing_function = custom_routing_function
        self.renormalize = renormalize

    @property
    def routing_method_type(self) -> RoutingMethodType:
        from vllm.model_executor.models.llama4 import Llama4MoE

        # NOTE: FLASHINFER_TRTLLM supports the Llama4 router.
        if self.custom_routing_function == Llama4MoE.custom_routing_function:
            return RoutingMethodType.Llama4
        return RoutingMethodType.Custom

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute routing using the custom routing function."""
        topk_weights, topk_ids = self.custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
        )

        return topk_weights.to(torch.float32), topk_ids.to(
            torch.int32 if indices_type is None else indices_type
        )
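Based on the call in _compute_routing() above, a compatible custom routing function accepts the keyword arguments hidden_states, gating_output, topk, and renormalize, and returns (topk_weights, topk_ids). A hedged sketch of that contract (the function itself is illustrative only):

import torch


def example_routing_function(  # illustrative, not part of vLLM
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    weights, ids = torch.topk(torch.sigmoid(gating_output), k=topk, dim=-1)
    if renormalize:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, ids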
vllm/model_executor/layers/fused_moe/router/fused_moe_router.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod

import torch

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType


class FusedMoERouter(ABC):
    """
    FusedMoERouter is an abstract class that provides a 'select_experts'
    method that is used for routing hidden states based on router logits.
    """

    @property
    @abstractmethod
    def routing_method_type(self) -> RoutingMethodType:
        raise NotImplementedError

    @abstractmethod
    def select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route the input hidden states to the top-k experts based on the
        router logits.

        Returns:
            (topk_weights, topk_ids)
            (tuple[torch.Tensor, torch.Tensor]):
                The weights and expert ids computation result.

            **Compatibility**: When EPLB is not enabled, the returned ids are
            equivalent to global logical ids, so they should be compatible with
            plain MoE implementations without redundant experts.
        """
        raise NotImplementedError
vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py (new file, 192 lines)
@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch

import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
    RoutingMethodType,
    get_routing_method_type,
)
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter


def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )

    return topk_weights, topk_indices


def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    ops.topk_sigmoid(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )

    return topk_weights, topk_indices


def fused_topk_bias(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    scoring_func: str = "softmax",
    indices_type: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if not rocm_aiter_ops.is_fused_moe_enabled():
        assert hidden_states.size(0) == gating_output.size(0), (
            "Number of tokens mismatch"
        )

        M, _ = hidden_states.size()

        topk_weights = torch.empty(
            M, topk, dtype=torch.float32, device=hidden_states.device
        )
        topk_ids = torch.empty(
            M,
            topk,
            dtype=torch.int32 if indices_type is None else indices_type,
            device=hidden_states.device,
        )
        token_expert_indices = torch.empty(
            M, topk, dtype=torch.int32, device=hidden_states.device
        )

        if scoring_func == "softmax":
            topk_weights, topk_ids = vllm_topk_softmax(
                topk_weights,
                topk_ids,
                token_expert_indices,
                gating_output,
                renormalize,
                e_score_correction_bias,
            )
            return topk_weights, topk_ids
        elif scoring_func == "sigmoid":
            topk_weights, topk_ids = vllm_topk_sigmoid(
                topk_weights,
                topk_ids,
                token_expert_indices,
                gating_output,
                renormalize,
                e_score_correction_bias,
            )
            return topk_weights, topk_ids
        else:
            raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # Pure-PyTorch path (reached only when the ROCm AITER fused MoE is enabled)
    n_routed_experts = gating_output.shape[-1]
    if scoring_func == "softmax":
        scores = gating_output.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    scores_for_choice = scores.view(
        -1, n_routed_experts
    ) + e_score_correction_bias.unsqueeze(0)

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_indices.to(
        torch.int32 if indices_type is None else indices_type
    )


class FusedTopKBiasRouter(BaseRouter):
    """Router using fused top-k with e_score_correction_bias."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        e_score_correction_bias: torch.Tensor,
        scoring_func: str,
        renormalize: bool = True,
        routed_scaling_factor: float = 1.0,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        self.e_score_correction_bias = e_score_correction_bias
        self.renormalize = renormalize
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return get_routing_method_type(
            scoring_func=self.scoring_func,
            top_k=self.top_k,
            renormalize=self.renormalize,
            num_expert_group=None,
            has_e_score_bias=True,
        )

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute routing using fused top-k with bias."""
        topk_weights, topk_ids = fused_topk_bias(
            hidden_states=hidden_states,
            gating_output=router_logits,
            e_score_correction_bias=self.e_score_correction_bias.data,
            topk=self.top_k,
            renormalize=self.renormalize,
            scoring_func=self.scoring_func,
            indices_type=indices_type,
        )

        if self.routed_scaling_factor != 1.0:
            topk_weights *= self.routed_scaling_factor

        return topk_weights, topk_ids
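The key property of the pure-PyTorch path above is that the bias influences which experts are selected but not the weights they receive. A tiny worked example with toy numbers:

import torch

scores = torch.tensor([[0.30, 0.25, 0.25, 0.20]])
bias = torch.tensor([0.0, 0.2, 0.0, 0.0])  # plays the role of e_score_correction_bias

# Selection uses the biased scores: expert 1 scores 0.25 + 0.2 = 0.45 and wins.
topk_ids = torch.topk(scores + bias, k=2, dim=-1)[1]  # tensor([[1, 0]])
# ...but the routing weights come from the unbiased scores.
topk_weights = scores.gather(1, topk_ids)             # tensor([[0.2500, 0.3000]])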
vllm/model_executor/layers/fused_moe/router/fused_topk_router.py (new file, 165 lines)
@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch

import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import (
    RoutingMethodType,
    get_routing_method_type,
)
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter


def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
    ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )

    return topk_weights, topk_indices


def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
    ops.topk_sigmoid(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )

    return topk_weights, topk_indices


def dispatch_topk_softmax_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    if use_rocm_aiter:
        return rocm_aiter_ops.topk_softmax
    return vllm_topk_softmax


def dispatch_topk_sigmoid_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    if use_rocm_aiter:
        return rocm_aiter_ops.topk_sigmoid
    return vllm_topk_sigmoid


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    indices_type: torch.dtype | None = None,
    scoring_func: str = "softmax",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    M, _ = hidden_states.size()

    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(
        M,
        topk,
        dtype=torch.int32 if indices_type is None else indices_type,
        device=hidden_states.device,
    )
    token_expert_indices = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )

    if scoring_func == "softmax":
        topk_func = dispatch_topk_softmax_func(
            use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
        )
        topk_weights, topk_ids = topk_func(
            topk_weights,
            topk_ids,
            token_expert_indices,
            gating_output.float(),
            renormalize,
        )

        return topk_weights, topk_ids, token_expert_indices
    elif scoring_func == "sigmoid":
        topk_func = dispatch_topk_sigmoid_func(
            use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
        )
        topk_weights, topk_ids = topk_func(
            topk_weights,
            topk_ids,
            token_expert_indices,
            gating_output.float(),
            renormalize,
        )

        return topk_weights, topk_ids, token_expert_indices
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")


class FusedTopKRouter(BaseRouter):
    """Default router using standard fused top-k routing."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        scoring_func: str = "softmax",
        renormalize: bool = True,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        self.renormalize = renormalize
        self.scoring_func = scoring_func

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return get_routing_method_type(
            scoring_func=self.scoring_func,
            top_k=self.top_k,
            renormalize=self.renormalize,
            num_expert_group=None,
            has_e_score_bias=False,
        )

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute routing using standard fused top-k."""
        topk_weights, topk_ids, token_expert_indices = fused_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
            indices_type=indices_type,
            scoring_func=self.scoring_func,
        )

        return topk_weights, topk_ids
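For intuition, the fused softmax path above is roughly equivalent to the following unfused PyTorch reference (a sketch of the semantics only, not the kernel's exact numerics or memory behavior):

import torch


def naive_topk_softmax(
    gating_output: torch.Tensor, topk: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    # What ops.topk_softmax computes in a single kernel launch.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids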
vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py (new file, 354 lines)
@@ -0,0 +1,354 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from functools import partial

import torch

from vllm import _custom_ops as ops
from vllm import envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
    RoutingMethodType,
    get_routing_method_type,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_grouped_topk,
)
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
    fused_topk_bias,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform


def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "sigmoid":
        # Fully fused kernel path for sigmoid
        topk_values, topk_indices = ops.grouped_topk(
            gating_output,  # raw logits
            num_expert_group,
            topk_group,
            topk,
            renormalize,
            routed_scaling_factor,
            e_score_correction_bias,
            1,  # scoring_func=1 for sigmoid
        )
    elif scoring_func == "softmax":
        # Apply softmax in Python, then use the fused kernel
        # TODO: Add support for softmax in the kernel
        scores = torch.softmax(gating_output, dim=-1)
        topk_values, topk_indices = ops.grouped_topk(
            scores,  # pre-computed scores
            num_expert_group,
            topk_group,
            topk,
            renormalize,
            routed_scaling_factor,
            e_score_correction_bias,
            0,  # scoring_func=0 (no activation, scores already computed)
        )
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # The fused kernel outputs float32 values and int32 indices directly
    return topk_values, topk_indices


# This is used by the Deepseek-V2 and Deepseek-V3 models
@torch.compile(
    dynamic=True,
    backend=current_platform.simple_compile_backend,
    options=maybe_disable_graph_partition(current_platform.simple_compile_backend),
)
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if (
        envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
        and current_platform.is_cuda()
        and num_expert_group <= 32
        and topk <= 32
        and e_score_correction_bias is not None
    ):
        return fused_grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
        )

    assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store the original scores before applying the correction bias. We use
        # biased scores for expert selection but original scores for routing weights.
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (
            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )
    else:
        group_scores = (
            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
        )  # [n, n_group]

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_is_batch_invariant()
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
        # Use the original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(
            tmp_scores, k=topk, dim=-1, sorted=use_sorted
        )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

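The group-limited selection above can be illustrated standalone with toy sizes (same masking idea as grouped_topk, pure PyTorch):

import torch

# 1 token, 8 experts in 4 groups of 2; keep the top-2 groups, then top-2 experts.
scores = torch.rand(1, 8)
num_expert_group, topk_group, topk = 4, 2, 2

group_scores = scores.view(1, num_expert_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=topk_group, dim=-1)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.repeat_interleave(8 // num_expert_group, dim=1).bool()
masked = scores.masked_fill(~score_mask, float("-inf"))
topk_weights, topk_ids = torch.topk(masked, k=topk, dim=-1)
# topk_ids only ever contains experts from the two highest-scoring groups.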
# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
    """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 models."""

    # --8<-- [end:grouped_topk]

    def __init__(
        self,
        topk: int,
        renormalize: bool,
        num_expert_group: int = 0,
        topk_group: int = 0,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        num_fused_shared_experts: int = 0,
    ) -> None:
        super().__init__()
        self.native_impl = grouped_topk
        self.topk = topk
        self.renormalize = renormalize
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_fused_shared_experts = num_fused_shared_experts

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.native_impl(
            hidden_states,
            gating_output,
            self.topk,
            self.renormalize,
            self.num_expert_group,
            self.topk_group,
            self.scoring_func,
            self.routed_scaling_factor,
            e_score_correction_bias,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.forward_native(
            hidden_states, gating_output, e_score_correction_bias
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        e_score_correction_bias: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if rocm_aiter_ops.is_fused_moe_enabled():
            if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
                assert self.num_fused_shared_experts == 0
            return rocm_aiter_grouped_topk(
                hidden_states,
                gating_output,
                self.topk,
                self.renormalize,
                self.num_expert_group,
                self.topk_group,
                self.scoring_func,
                self.routed_scaling_factor,
                e_score_correction_bias,
                self.num_fused_shared_experts,
            )
        else:
            return self.forward_native(
                hidden_states, gating_output, e_score_correction_bias
            )


class GroupedTopKRouter(BaseRouter):
    """Router using grouped top-k routing (e.g., DeepSeekV2/V3)."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        num_expert_group: int,
        topk_group: int,
        renormalize: bool = True,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        num_fused_shared_experts: int = 0,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.e_score_correction_bias = e_score_correction_bias
        self.num_fused_shared_experts = num_fused_shared_experts

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return get_routing_method_type(
            scoring_func=self.scoring_func,
            top_k=self.top_k,
            renormalize=self.renormalize,
            num_expert_group=self.num_expert_group,
            has_e_score_bias=self.e_score_correction_bias is not None,
        )

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute routing using grouped top-k."""

        def valid_grouping() -> bool:
            # Check that num_experts is greater than num_expert_group
            # and divisible by it
            num_experts = router_logits.shape[-1]
            if num_experts <= self.num_expert_group:
                return False
            return num_experts % self.num_expert_group == 0

        if not valid_grouping():
            if self.e_score_correction_bias is not None:
                topk_weights, topk_ids = fused_topk_bias(
                    hidden_states=hidden_states,
                    gating_output=router_logits,
                    e_score_correction_bias=self.e_score_correction_bias.data,
                    topk=self.top_k,
                    renormalize=self.renormalize,
                )
                if self.routed_scaling_factor != 1.0:
                    topk_weights *= self.routed_scaling_factor
            else:
                topk_weights, topk_ids, token_expert_indices = fused_topk(
                    hidden_states=hidden_states,
                    gating_output=router_logits,
                    topk=self.top_k,
                    renormalize=self.renormalize,
                    indices_type=indices_type,
                )
            return topk_weights, topk_ids

        # Select the grouped_topk implementation
        if rocm_aiter_ops.is_fused_moe_enabled():
            if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
                assert self.num_fused_shared_experts == 0
            grouped_topk_impl = partial(
                rocm_aiter_grouped_topk,
                num_fused_shared_experts=self.num_fused_shared_experts,
            )
        else:
            grouped_topk_impl = grouped_topk

        topk_weights, topk_ids = grouped_topk_impl(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=self.top_k,
            renormalize=self.renormalize,
            num_expert_group=self.num_expert_group,
            topk_group=self.topk_group,
            scoring_func=self.scoring_func,
            routed_scaling_factor=self.routed_scaling_factor,
            e_score_correction_bias=self.e_score_correction_bias,
        )

        return topk_weights, topk_ids
vllm/model_executor/layers/fused_moe/router/router_factory.py (new file, 169 lines)
@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch

import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
    CustomRoutingRouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
    FusedTopKBiasRouter,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
    FusedTopKRouter,
)
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
    GroupedTopKRouter,
)
from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
    RoutingSimulatorRouter,
)

EMPTY_EPLB_STATE: EplbLayerState = EplbLayerState()


def create_fused_moe_router(
    # common parameters
    top_k: int,
    global_num_experts: int,
    renormalize: bool = True,
    indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    # grouped topk parameters
    use_grouped_topk: bool = False,
    num_expert_group: int | None = None,
    topk_group: int | None = None,
    scoring_func: str = "softmax",
    num_fused_shared_experts: int = 0,
    # grouped topk + fused topk bias parameters
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
    # custom routing parameters
    custom_routing_function: Callable | None = None,
    # eplb parameters
    enable_eplb: bool = False,
    eplb_state: EplbLayerState = EMPTY_EPLB_STATE,
) -> FusedMoERouter:
    """
    Factory function to create the appropriate FusedMoERouter subclass based on
    the provided parameters.

    The selection logic follows this priority order:
    1. RoutingSimulatorRouter - if the VLLM_MOE_ROUTING_SIMULATION_STRATEGY env var is set
    2. GroupedTopKRouter - if use_grouped_topk is True
    3. CustomRoutingRouter - if custom_routing_function is not None
    4. FusedTopKBiasRouter - if e_score_correction_bias is not None
    5. FusedTopKRouter - default fallback

    Common arguments:
        top_k: Number of experts to select per token
        global_num_experts: Total number of experts in the model
        renormalize: Whether to renormalize the routing weights
        indices_type_getter: Function to get the desired indices dtype

    Grouped topk arguments:
        use_grouped_topk: Whether to use grouped top-k routing
        num_expert_group: Number of expert groups (for grouped routing)
        topk_group: Top-k within each group (for grouped routing)
        scoring_func: Scoring function to use ("softmax" or "sigmoid")
        num_fused_shared_experts: Number of fused shared experts (for ROCm AITER)

    Grouped topk and fused topk bias arguments:
        routed_scaling_factor: Scaling factor for routed weights
        e_score_correction_bias: Optional bias correction for expert scores

    Custom routing arguments:
        custom_routing_function: Optional custom routing function

    EPLB arguments:
        enable_eplb: Whether EPLB is enabled
        eplb_state: EPLB (Expert Parallelism Load Balancing) state

    Returns:
        An instance of the appropriate FusedMoERouter subclass
    """

    routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
    if routing_strategy != "":
        return RoutingSimulatorRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )

    if use_grouped_topk:
        assert custom_routing_function is None
        if num_expert_group is None or topk_group is None:
            raise ValueError(
                "num_expert_group and topk_group must be provided when "
                "use_grouped_topk is True"
            )
        grouped_topk_router = GroupedTopKRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            renormalize=renormalize,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            num_fused_shared_experts=num_fused_shared_experts,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )
        if (
            grouped_topk_router.routing_method_type != RoutingMethodType.Unspecified
            or num_expert_group > 1
            or topk_group > 1
        ):
            return grouped_topk_router

        # If the routing method for GroupedTopKRouter is Unspecified and there
        # is only one group, fall back to standard top-k routing
        use_grouped_topk = False
        num_expert_group = None
        topk_group = None

    if custom_routing_function is not None:
        return CustomRoutingRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            custom_routing_function=custom_routing_function,
            renormalize=renormalize,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )

    if e_score_correction_bias is not None:
        return FusedTopKBiasRouter(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            e_score_correction_bias=e_score_correction_bias,
            scoring_func=scoring_func,
            renormalize=renormalize,
            routed_scaling_factor=routed_scaling_factor,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )

    return FusedTopKRouter(
        top_k=top_k,
        global_num_experts=global_num_experts,
        eplb_state=eplb_state,
        renormalize=renormalize,
        scoring_func=scoring_func,
        enable_eplb=enable_eplb,
        indices_type_getter=indices_type_getter,
    )
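A usage sketch of the factory (toy shapes; with no grouped/custom/bias arguments it falls through to the default FusedTopKRouter, which ultimately calls vLLM's fused kernels, so running select_experts assumes a supported device):

import torch

from vllm.model_executor.layers.fused_moe.router.router_factory import (
    create_fused_moe_router,
)

router = create_fused_moe_router(
    top_k=2,
    global_num_experts=8,
    renormalize=True,
    scoring_func="softmax",
)  # -> FusedTopKRouter, the default fallback

hidden_states = torch.randn(4, 16)  # [num_tokens, hidden_size]
router_logits = torch.randn(4, 8)   # [num_tokens, num_experts]
topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)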
vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py (new file, 347 lines)
@@ -0,0 +1,347 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

import torch

import vllm.envs as envs
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter

logger = init_logger(__name__)


class RoutingStrategy(ABC):
    """Base class for token-to-expert routing strategies."""

    @abstractmethod
    def route_tokens(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        indices_type: torch.dtype | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route tokens to experts.

        Args:
            hidden_states: Input hidden states [num_tokens, hidden_size]
            router_logits: Router logits [num_tokens, num_experts]
            top_k: Number of experts to select per token
            indices_type: Data type for expert indices

        Returns:
            tuple of (topk_weights, topk_ids)
        """
        pass


class DistributionBasedRouting(RoutingStrategy):
    """
    Distribution-based random routing strategy with configurable distributions.

    This routing strategy randomly selects experts for each token based on
    different probability distributions. Currently supports uniform and normal
    distributions for testing different routing patterns.
    """

    def __init__(self, distribution: str = "uniform", **distribution_params: Any):
        """
        Initialize distribution-based routing.

        Args:
            distribution: Type of distribution to use for sampling
                - "uniform": Uniform distribution (default)
                - "normal": Normal/Gaussian distribution
            **distribution_params: Parameters specific to the
                chosen distribution
                For "uniform": No additional parameters needed
                For "normal": mean (default: 0.0), std (default: 1.0)
        """
        self.distribution = distribution.lower()
        self.distribution_params = distribution_params

        # Validate distribution and parameters
        self._validate_distribution_params()

    def _validate_distribution_params(self):
        """Validate the distribution type and parameters."""
        valid_distributions = ["uniform", "normal"]

        if self.distribution not in valid_distributions:
            raise ValueError(
                f"Unsupported distribution: {self.distribution}. "
                f"Supported distributions: {valid_distributions}"
            )

        # Set default parameters if not provided
        if self.distribution == "normal":
            self.distribution_params.setdefault("mean", 0.0)
            self.distribution_params.setdefault("std", 1.0)

    def route_tokens(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        indices_type: torch.dtype | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Randomly select experts for each token using the specified distribution.

        Args:
            hidden_states: Input hidden states [num_tokens, hidden_size]
            router_logits: Router logits [num_tokens, num_experts]
            top_k: Number of experts to select per token
            indices_type: Data type for expert indices

        Returns:
            tuple of (topk_weights, topk_ids) where:
            - topk_weights: Weights based on distribution sampling
            - topk_ids: Expert indices sampled from the distribution
        """
        num_tokens = hidden_states.shape[0]
        num_experts = router_logits.shape[-1]

        if indices_type is None:
            indices_type = torch.long

        # Generate expert IDs based on the specified distribution
        topk_ids = self._sample_expert_ids(
            num_tokens, num_experts, top_k, hidden_states.device, indices_type
        )

        # Generate weights based on the distribution
        topk_weights = self._generate_weights(num_tokens, top_k, hidden_states.device)

        return topk_weights, topk_ids

    def _sample_expert_ids(
        self,
        num_tokens: int,
        num_experts: int,
        top_k: int,
        device: torch.device,
        indices_type: torch.dtype,
    ) -> torch.Tensor:
        """Sample expert IDs based on the specified distribution."""

        if self.distribution == "uniform":
            # Uniform random sampling
            return torch.randint(
                low=0,
                high=num_experts,
                size=(num_tokens, top_k),
                dtype=indices_type,
                device=device,
            )

        elif self.distribution == "normal":
            # For normal distribution, sample continuous values and map to
            # expert IDs
            continuous_samples = self._sample_continuous_distribution(
                num_tokens, top_k, device
            )

            # Map continuous samples to expert indices
            # Normalize to [0, 1] range and scale to [0, num_experts)
            normalized_samples = self._normalize_samples(continuous_samples)
            expert_ids = (normalized_samples * num_experts).long()
            expert_ids = torch.clamp(expert_ids, 0, num_experts - 1)

            return expert_ids.to(dtype=indices_type)

        else:
            raise ValueError(f"Unsupported distribution: {self.distribution}")

    def _sample_continuous_distribution(
        self, num_tokens: int, top_k: int, device: torch.device
    ) -> torch.Tensor:
        """Sample from continuous distributions."""
        shape = (num_tokens, top_k)

        if self.distribution == "normal":
            mean = self.distribution_params["mean"]
            std = self.distribution_params["std"]
            return torch.normal(mean, std, size=shape, device=device)

        else:
            raise ValueError(
                f"Unsupported continuous distribution: {self.distribution}"
            )

    def _normalize_samples(self, samples: torch.Tensor) -> torch.Tensor:
        """Normalize samples to [0, 1] range."""
        if self.distribution == "normal":
            # Use sigmoid to map the normal distribution to [0, 1]
            return torch.sigmoid(samples)

        else:
            raise ValueError(
                f"Unsupported distribution for normalization: {self.distribution}"
            )

    def _generate_weights(
        self, num_tokens: int, top_k: int, device: torch.device
    ) -> torch.Tensor:
        """Generate weights based on the distribution."""
        if self.distribution == "uniform":
            # All-ones weights for uniform distribution
            return torch.ones(
                (num_tokens, top_k),
                dtype=torch.float32,
                device=device,
            )

        elif self.distribution == "normal":
            # For normal distribution, generate weights from the same
            # distribution
            continuous_weights = self._sample_continuous_distribution(
                num_tokens, top_k, device
            )
            # Normalize to positive values that sum to 1
            weights = torch.abs(continuous_weights)
            weights = weights / weights.sum(dim=-1, keepdim=True)
            return weights

        else:
            raise ValueError(
                f"Unsupported distribution for weight generation: {self.distribution}"
            )

    def get_distribution_info(self) -> dict:
        """Get information about the current distribution configuration."""
        return {
            "distribution": self.distribution,
            "parameters": self.distribution_params.copy(),
        }

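A quick usage sketch of the strategy above (pure PyTorch, toy shapes; the import path is the module shown in this diff):

import torch

from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
    DistributionBasedRouting,
)

strategy = DistributionBasedRouting(distribution="normal", mean=0.0, std=2.0)

hidden_states = torch.randn(4, 16)  # [num_tokens, hidden_size]
router_logits = torch.randn(4, 8)   # only its last dim (num_experts) is read
topk_weights, topk_ids = strategy.route_tokens(hidden_states, router_logits, top_k=2)
# Each weights row is positive and sums to 1; ids come from normal samples
# squashed through a sigmoid, so they cluster near expert num_experts / 2.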
class RoutingSimulator:
    """
    Token-to-Expert Routing Simulator.

    This class provides a framework for testing and comparing different
    routing strategies for MoE models. It can simulate routing behavior
    and collect statistics for analysis.
    """

    # Class-level registry of routing strategies
    _routing_strategies: dict[str, RoutingStrategy] = {
        # Basic routing strategies
        "uniform_random": DistributionBasedRouting(distribution="uniform"),
        "normal_routing": DistributionBasedRouting(
            distribution="normal", mean=0.0, std=1.0
        ),
    }

    @classmethod
    def register_strategy(cls, name: str, strategy: RoutingStrategy):
        """
        Register a custom routing strategy.

        Args:
            name: Name of the strategy
            strategy: RoutingStrategy instance
        """
        cls._routing_strategies[name] = strategy

    @classmethod
    def get_available_strategies(cls) -> list[str]:
        """
        Get the list of available routing strategy names.

        Returns:
            List of available strategy names
        """
        return list(cls._routing_strategies.keys())

    @staticmethod
    def simulate_routing(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        strategy_name: str,
        top_k: int,
        indices_type: torch.dtype | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Simulate token-to-expert routing using the specified strategy.

        Args:
            hidden_states: Input hidden states [num_tokens, hidden_size]
            router_logits: Router logits [num_tokens, num_experts]
            strategy_name: Name of the routing strategy to use
            top_k: Number of experts to select per token
            indices_type: Data type for expert indices

        Returns:
            tuple of (topk_weights, topk_ids)
        """
        if strategy_name not in RoutingSimulator._routing_strategies:
            raise ValueError(
                f"Unknown routing strategy: {strategy_name}. "
                f"Available strategies: "
                f"{list(RoutingSimulator._routing_strategies.keys())}"
            )
        logger.warning_once(
            "Simulating MoE routing using a %s strategy. "
            "This should only be used for performance testing. "
            "Model outputs will not be valid.",
            strategy_name,
        )

        strategy = RoutingSimulator._routing_strategies[strategy_name]
        return strategy.route_tokens(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=top_k,
            indices_type=indices_type,
        )


class RoutingSimulatorRouter(BaseRouter):
    """Router that uses routing simulation strategies for testing/debugging."""

    def __init__(
        self,
        top_k: int,
        global_num_experts: int,
        eplb_state: EplbLayerState,
        enable_eplb: bool = False,
        indices_type_getter: Callable[[], torch.dtype | None] | None = None,
    ):
        super().__init__(
            top_k=top_k,
            global_num_experts=global_num_experts,
            eplb_state=eplb_state,
            enable_eplb=enable_eplb,
            indices_type_getter=indices_type_getter,
        )

    @property
    def routing_method_type(self) -> RoutingMethodType:
        return RoutingMethodType.Simulated

    def _compute_routing(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Use the routing simulator to compute routing."""
        routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY
        topk_weights, topk_ids = RoutingSimulator.simulate_routing(
            hidden_states=hidden_states,
            router_logits=router_logits,
            strategy_name=routing_strategy,
            top_k=self.top_k,
            indices_type=indices_type,
        )
        return topk_weights, topk_ids
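Finally, a hedged sketch of plugging a custom strategy into the registry (RoundRobinRouting is hypothetical; the environment variable is the one RoutingSimulatorRouter reads):

import torch

from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import (
    RoutingSimulator,
    RoutingStrategy,
)


class RoundRobinRouting(RoutingStrategy):  # hypothetical example strategy
    def route_tokens(self, hidden_states, router_logits, top_k, indices_type=None):
        num_tokens = hidden_states.shape[0]
        num_experts = router_logits.shape[-1]
        device = hidden_states.device
        # Token i gets experts i, i+1, ..., i+top_k-1 (mod num_experts).
        base = torch.arange(num_tokens, device=device).unsqueeze(1)
        offsets = torch.arange(top_k, device=device)
        topk_ids = ((base + offsets) % num_experts).to(dtype=indices_type or torch.long)
        topk_weights = torch.full((num_tokens, top_k), 1.0 / top_k, device=device)
        return topk_weights, topk_ids


RoutingSimulator.register_strategy("round_robin", RoundRobinRouting())
# Select it for a run with: VLLM_MOE_ROUTING_SIMULATION_STRATEGY=round_robin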