[Feature] Support fine-grained shared expert overlap (#5482)
Adds fine-grained control over shared expert overlap so that shared-expert
computation does not contend for resources with routed-expert dispatch and combine.
- vLLM version: v0.13.0
- vLLM main: 5326c89803
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
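The diff below implements the overlap hook: an NPU stream event is recorded immediately before token dispatch and again before token combine, and both events are handed back to the caller in the new FusedExpertsResult. The following is a minimal sketch of how a caller could use those events to run shared experts on a secondary stream; the function and argument names are illustrative, not part of this commit, and torch.npu is assumed to mirror torch.cuda's stream/event semantics (the same assumption the diff makes when it calls record_event()).

import torch
import torch_npu  # noqa: F401  # registers the torch.npu namespace on Ascend

def overlap_shared_experts(result, shared_experts, hidden_states, secondary_stream):
    # Illustrative only. Phase 1: start shared-expert compute once every op
    # enqueued before token dispatch has been submitted, so it overlaps the
    # dispatch communication instead of contending with pre-dispatch kernels.
    secondary_stream.wait_event(result.before_dispatch_evt)
    with torch.npu.stream(secondary_stream):
        shared_out = shared_experts(hidden_states)
    # Phase 2: make the main stream wait for the shared-expert output before
    # it is added to the routed output (before_combine_evt could likewise be
    # used to gate a second, finer-grained shared phase).
    torch.npu.current_stream().wait_stream(secondary_stream)
    return result.routed_out, shared_out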
@@ -17,7 +17,7 @@ from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Dict, Optional

import torch
from vllm.forward_context import get_forward_context

@@ -51,6 +51,11 @@ def setup_moe_comm_method(moe_config):
@dataclass
class FusedExpertsResult:
    routed_out: torch.Tensor
    # This field is for shared experts and should be set by the MoE
    # communication method that supports shared experts in parallel with routed
    # experts.
    before_dispatch_evt: torch.npu.Event | None = None
    before_combine_evt: torch.npu.Event | None = None
    # For dynamic_eplb
    group_list_type: int | None = None
    expert_tokens: torch.Tensor | None = None

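Both event fields above default to None, as do the dynamic-EPLB fields, so a communication method that does not run shared experts in parallel can simply leave them unset. An illustrative guard for callers, continuing the sketch above (only the dataclass field names come from this commit):

def maybe_wait_for_dispatch(result, secondary_stream):
    # Overlap is opt-in: synchronize only when the comm method actually
    # recorded the pre-dispatch event; otherwise fall back to the default,
    # non-overlapped shared-expert path.
    if result.before_dispatch_evt is None:
        return False
    secondary_stream.wait_event(result.before_dispatch_evt)
    return True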
@@ -108,10 +113,6 @@ class MoECommMethod(ABC):
        w2_scale_bias: torch.Tensor = None,
        w1_offset: Optional[torch.Tensor] = None,
        w2_offset: Optional[torch.Tensor] = None,
        # For Cube/Vector parallel
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        # For load balance
        log2phy: torch.Tensor = None,
        need_trans: bool = False,

@@ -126,6 +127,7 @@ class MoECommMethod(ABC):
        moe_comm_method = get_forward_context().moe_comm_method
        assert moe_comm_method is not None, "Missing communication context"

        before_dispatch_evt = torch.npu.current_stream().record_event()
        dispatch_results = self.token_dispatcher.token_dispatch(
            hidden_states=hidden_states,
            topk_weights=topk_weights,

@@ -134,9 +136,6 @@ class MoECommMethod(ABC):
            log2phy=log2phy,
            global_redundant_expert_num=self.moe_config.
            global_redundant_expert_num,
            shared_experts=shared_experts,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            mc2_mask=mc2_mask,
            apply_router_weight_on_input=apply_router_weight_on_input,
            with_quant=use_int8_w8a8 or use_int4_w4a8,

@@ -162,12 +161,15 @@ class MoECommMethod(ABC):
            need_trans=need_trans,
            dynamic_eplb=dynamic_eplb)

        before_combine_evt = torch.npu.current_stream().record_event()
        combine_results = self.token_dispatcher.token_combine(
            hidden_states=mlp_output,
            context_metadata=dispatch_results.context_metadata)

        return FusedExpertsResult(
            routed_out=combine_results.routed_out,
            before_dispatch_evt=before_dispatch_evt,
            before_combine_evt=before_combine_evt,
            group_list_type=dispatch_results.group_list_type,
            expert_tokens=dispatch_results.group_list)

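The hunk above also threads the dispatcher's group list back to the caller as expert_tokens, together with group_list_type, so dynamic-EPLB bookkeeping can read per-expert token counts from the same return object. An illustrative consumer (everything except the field names is an assumption):

def eplb_token_counts(result):
    # expert_tokens carries the dispatcher's group list (per-expert token
    # counts); group_list_type says how downstream code should interpret it.
    # Both are None when dynamic EPLB is disabled.
    if result.expert_tokens is None:
        return None
    return result.group_list_type, result.expert_tokens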
@@ -284,10 +286,6 @@ class FusedMC2CommImpl(MoECommMethod):
        w2_scale_bias: torch.Tensor = None,
        w1_offset: Optional[torch.Tensor] = None,
        w2_offset: Optional[torch.Tensor] = None,
        # For Cube/Vector parallel
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        # For load balance
        log2phy: torch.Tensor = None,
        need_trans: bool = False,