[CustomOp] Register AscendSharedFusedMoE custom op (#2980)

### What this PR does / why we need it?
Register `AscendSharedFusedMoE` custom op.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

`DeepSeek-V2-Lite` is a MoE model with shared experts.

Test:

```bash
vllm serve /root/.cache/modelscope/hub/models/deepseek-ai/DeepSeek-V2-Lite \
--trust-remote-code \
--enforce-eager \
--no-enable-prefix-caching \
--gpu-memory-utilization 0.95

curl -X POST http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "/root/.cache/modelscope/hub/models/deepseek-ai/DeepSeek-V2-Lite",
        "messages": [
            {"role": "user", "content": "介绍一下联通公司?"}
        ],
        "stream": false,
        "max_tokens": 100
    }'
```

Output:

```bash
中国联合网络通信集团有限公司(简称“中国联通”)于2009年1月6日在原中国网通和原中国联通的基础上合并组建而成,在国内31个省(自治区、直辖市)和境外多个国家和地区设有分支机构,是中国唯一一家在纽约、香港、上海三地同时上市的电信运营企业,连续多年入选“世界500强企业”。\n\n中国联通主要经营固定通信业务,移动通信业务,国内
```


- vLLM version: v0.10.2
- vLLM main:
486c5599e3

---------

Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2025-09-19 19:05:01 +08:00
committed by GitHub
parent 05a700d370
commit 8326f15ecf
4 changed files with 18 additions and 26 deletions

View File

@@ -27,6 +27,7 @@ from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -415,7 +416,7 @@ class AscendFusedMoE(FusedMoE):
expert_data.copy_(loaded_weight)
class AscendSharedFusedMoE(AscendFusedMoE):
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
def __init__(
self,
@@ -423,7 +424,7 @@ class AscendSharedFusedMoE(AscendFusedMoE):
use_overlapped: bool = True,
**kwargs,
):
super().__init__(**kwargs)
AscendFusedMoE.__init__(self, **kwargs)
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped
self.shared_expert_stream = None
@@ -452,7 +453,8 @@ class AscendSharedFusedMoE(AscendFusedMoE):
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
shared_out = tensor_model_parallel_all_reduce(shared_out)
fused_out = super().forward(
_, fused_out = AscendFusedMoE.forward(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
@@ -461,6 +463,16 @@ class AscendSharedFusedMoE(AscendFusedMoE):
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
return shared_out, fused_out
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
shared_output = torch.empty(1)
fused_output = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_output, fused_output
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading

View File

@@ -18,4 +18,3 @@
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
import vllm_ascend.patch.worker.patch_common.patch_shared_fused_moe # noqa

View File

@@ -1,21 +0,0 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm.model_executor.models import deepseek_v2, llama4
from vllm_ascend.ops.common_fused_moe import AscendSharedFusedMoE
deepseek_v2.SharedFusedMoE = AscendSharedFusedMoE
llama4.SharedFusedMoE = AscendSharedFusedMoE

View File

@@ -498,7 +498,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
@@ -525,6 +526,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"LogitsProcessor": AscendLogitsProcessor,
"RMSNorm": AscendRMSNorm,
"FusedMoE": AscendFusedMoE,
"SharedFusedMoE": AscendSharedFusedMoE,
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
}