### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/fused_moe/comm_utils.py` |
| `vllm_ascend/ops/fused_moe/experts_selector.py` |
| `vllm_ascend/ops/fused_moe/fused_moe.py` |
| `vllm_ascend/ops/fused_moe/moe_comm_method.py` |
| `vllm_ascend/ops/fused_moe/moe_mlp.py` |
| `vllm_ascend/ops/fused_moe/prepare_finalize.py` |
| `vllm_ascend/ops/fused_moe/token_dispatcher.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main: d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
@@ -17,7 +17,6 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from vllm.forward_context import get_forward_context
|
||||
@@ -27,18 +26,24 @@ import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
||||
PrepareAndFinalize, PrepareAndFinalizeWithAll2All,
|
||||
PrepareAndFinalizeWithAllGather, PrepareAndFinalizeWithMC2, QuantType)
|
||||
PrepareAndFinalize,
|
||||
PrepareAndFinalizeWithAll2All,
|
||||
PrepareAndFinalizeWithAllGather,
|
||||
PrepareAndFinalizeWithMC2,
|
||||
QuantType,
|
||||
)
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
||||
MoETokenDispatcher, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
MoETokenDispatcher,
|
||||
TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2,
|
||||
)
|
||||
|
||||
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
|
||||
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
|
||||
|
||||
|
||||
def get_moe_comm_method(
|
||||
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
|
||||
return _MoECommMethods.get(moe_comm_type, None)
|
||||
def get_moe_comm_method(moe_comm_type: MoECommType | None) -> MoECommMethod | None:
|
||||
return _MoECommMethods.get(moe_comm_type)
|
||||
|
||||
|
||||
def setup_moe_comm_method(moe_config):
|
||||
@@ -50,6 +55,7 @@ def setup_moe_comm_method(moe_config):
|
||||
|
||||
def set_gmmswigluquant_method():
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
return ascend_config.ascend_fusion_config.fusion_ops_gmmswigluquant
|
||||
|
||||
@@ -84,51 +90,46 @@ class MoECommMethod(ABC):
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type: QuantType = QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp,
|
||||
replace_allreduce, quant_type)
|
||||
hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type
|
||||
)
|
||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
hidden_states = self.prepare_finalize.finalize(hidden_states,
|
||||
reduce_results,
|
||||
context_metadata)
|
||||
def finalize(
|
||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata)
|
||||
return hidden_states
|
||||
|
||||
def fused_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: Optional[torch.Tensor] = None,
|
||||
w2_offset: Optional[torch.Tensor] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: Optional[torch.Tensor] = None):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: list[torch.Tensor] | None = None,
|
||||
w2_scale: list[torch.Tensor] | None = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: torch.Tensor | None = None,
|
||||
w2_offset: torch.Tensor | None = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
):
|
||||
# Check constraints
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16, torch.int8
|
||||
]
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
@@ -143,13 +144,13 @@ class MoECommMethod(ABC):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
global_redundant_expert_num=self.moe_config.
|
||||
global_redundant_expert_num,
|
||||
global_redundant_expert_num=self.moe_config.global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8,
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
pertoken_scale=pertoken_scale)
|
||||
pertoken_scale=pertoken_scale,
|
||||
)
|
||||
|
||||
mlp_output = unified_apply_mlp(
|
||||
hidden_states=dispatch_results.hidden_states,
|
||||
@@ -168,29 +169,29 @@ class MoECommMethod(ABC):
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16,
|
||||
fusion=use_int8_w8a8 and self.use_fusion_ops,
|
||||
need_trans=need_trans,
|
||||
dynamic_eplb=dynamic_eplb)
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
)
|
||||
|
||||
before_combine_evt = torch.npu.current_stream().record_event()
|
||||
combine_results = self.token_dispatcher.token_combine(
|
||||
hidden_states=mlp_output,
|
||||
context_metadata=dispatch_results.context_metadata)
|
||||
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
|
||||
)
|
||||
|
||||
return FusedExpertsResult(
|
||||
routed_out=combine_results.routed_out,
|
||||
before_dispatch_evt=before_dispatch_evt,
|
||||
before_combine_evt=before_combine_evt,
|
||||
group_list_type=dispatch_results.group_list_type,
|
||||
expert_tokens=dispatch_results.group_list)
|
||||
expert_tokens=dispatch_results.group_list,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _get_token_dispatcher(self) -> MoETokenDispatcher:
|
||||
raise NotImplementedError(
|
||||
"_get_token_dispatcher function not implemented.")
|
||||
raise NotImplementedError("_get_token_dispatcher function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def _get_prepare_finalize(self) -> PrepareAndFinalize:
|
||||
raise NotImplementedError(
|
||||
"_get_prepare_finalize function not implemented.")
|
||||
raise NotImplementedError("_get_prepare_finalize function not implemented.")
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
|
||||
@@ -216,7 +217,8 @@ class AllGatherCommImpl(MoECommMethod):
|
||||
return TokenDispatcherWithAllGather(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
num_local_experts=self.moe_config.num_local_experts,
|
||||
)
|
||||
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
@@ -227,7 +229,7 @@ class MC2CommImpl(MoECommMethod):
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
|
||||
3. `enable_expert_parallel=False` is not supported.
|
||||
|
||||
|
||||
This implementation uses the MC2 communication method, which is optimized for
|
||||
Communication and Computation parallelism on Ascend devices.
|
||||
"""
|
||||
@@ -253,7 +255,8 @@ class AlltoAllCommImpl(MoECommMethod):
|
||||
return TokenDispatcherWithAll2AllV(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
num_local_experts=self.moe_config.num_local_experts,
|
||||
)
|
||||
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
@@ -264,7 +267,7 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
|
||||
3. `enable_expert_parallel=False` is not supported.
|
||||
|
||||
|
||||
This implementation uses the MC2 communication method, which is optimized for
|
||||
Communication and Computation parallelism on Ascend devices.
|
||||
"""
|
||||
@@ -276,36 +279,36 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
return PrepareAndFinalizeWithMC2(self.moe_config)
|
||||
|
||||
def fused_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: Optional[torch.Tensor] = None,
|
||||
w2_offset: Optional[torch.Tensor] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: Optional[torch.Tensor] = None):
|
||||
assert not (
|
||||
w1_scale is None or w2_scale is None
|
||||
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: list[torch.Tensor] | None = None,
|
||||
w2_scale: list[torch.Tensor] | None = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: torch.Tensor | None = None,
|
||||
w2_offset: torch.Tensor | None = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
):
|
||||
assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||
|
||||
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \
|
||||
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), (
|
||||
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
|
||||
)
|
||||
|
||||
# Apply log2phy if needed
|
||||
if log2phy is not None:
|
||||
@@ -346,10 +349,8 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
ep_rank_size=self.token_dispatcher.ep_world_size,
|
||||
ep_rank_id=self.token_dispatcher.ep_rank_id,
|
||||
moe_expert_num=self.moe_config.num_experts,
|
||||
global_bs=self.token_dispatcher.global_bs)
|
||||
global_bs=self.token_dispatcher.global_bs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||
return FusedExpertsResult(routed_out=out,
|
||||
group_list_type=group_list_type,
|
||||
expert_tokens=expert_tokens)
|
||||
raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||
return FusedExpertsResult(routed_out=out, group_list_type=group_list_type, expert_tokens=expert_tokens)
|
||||
|
||||
Reference in New Issue
Block a user