[Feature] Support fine-grained shared expert overlap (#5482)

Add fine-grained control over shared expert overlap to prevent resource
contention between shared and routed expert computation.

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2026-01-17 11:53:22 +08:00
committed by GitHub
parent 48e10de8c9
commit 22f253142a
9 changed files with 203 additions and 130 deletions

View File

@@ -17,7 +17,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Dict, Optional
import torch
from vllm.forward_context import get_forward_context
@@ -51,6 +51,11 @@ def setup_moe_comm_method(moe_config):
@dataclass
class FusedExpertsResult:
routed_out: torch.Tensor
# This field is for shared experts and should be set by the MoE
# communication method that supports shared experts in parallel with routed
# experts.
before_dispatch_evt: torch.npu.Event | None = None
before_combine_evt: torch.npu.Event | None = None
# For dynamic_eplb
group_list_type: int | None = None
expert_tokens: torch.Tensor | None = None
@@ -108,10 +113,6 @@ class MoECommMethod(ABC):
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
@@ -126,6 +127,7 @@ class MoECommMethod(ABC):
moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
before_dispatch_evt = torch.npu.current_stream().record_event()
dispatch_results = self.token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
@@ -134,9 +136,6 @@ class MoECommMethod(ABC):
log2phy=log2phy,
global_redundant_expert_num=self.moe_config.
global_redundant_expert_num,
shared_experts=shared_experts,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8,
@@ -162,12 +161,15 @@ class MoECommMethod(ABC):
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)
before_combine_evt = torch.npu.current_stream().record_event()
combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output,
context_metadata=dispatch_results.context_metadata)
return FusedExpertsResult(
routed_out=combine_results.routed_out,
before_dispatch_evt=before_dispatch_evt,
before_combine_evt=before_combine_evt,
group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list)
@@ -284,10 +286,6 @@ class FusedMC2CommImpl(MoECommMethod):
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,