2025-02-21 17:07:37 +08:00
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2025-03-11 21:08:02 +08:00
# Copyright 2023 The vLLM team.
2025-02-21 17:07:37 +08:00
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2025-04-17 14:59:56 +08:00
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
2025-02-21 17:07:37 +08:00
2025-06-09 19:28:11 +08:00
import os
2025-08-28 10:13:35 +08:00
from typing import Any , Callable , Optional
2025-02-21 17:07:37 +08:00
import torch
2025-05-24 14:29:36 +08:00
import torch . distributed as dist
2025-02-21 17:07:37 +08:00
import torch_npu
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147).
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR removes the env `VLLM_ENABLE_MC2`. It is unnecessary because the
decision can be made from the current scenario, and keeping it only adds
complexity.
2. This PR removes the env `USING_LCCL_COM`, because it is obsolete.
3. `additional_config.expert_tensor_parallel_size` is obsolete; the
parameter `enable_expert_parallel` is used instead, consistent with
vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
from torch import nn
2025-04-19 17:38:18 +08:00
from vllm . config import get_current_vllm_config
2025-08-28 10:13:35 +08:00
from vllm . distributed import ( get_tensor_model_parallel_rank ,
2025-05-24 14:29:36 +08:00
get_tensor_model_parallel_world_size ,
2025-05-16 12:14:55 +08:00
tensor_model_parallel_all_reduce )
2025-07-21 09:08:04 +08:00
from vllm . distributed . parallel_state import ( get_dp_group , get_ep_group ,
get_tp_group )
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147).
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR removes the env `VLLM_ENABLE_MC2`. It is unnecessary because the
decision can be made from the current scenario, and keeping it only adds
complexity.
2. This PR removes the env `USING_LCCL_COM`, because it is obsolete.
3. `additional_config.expert_tensor_parallel_size` is obsolete; the
parameter `enable_expert_parallel` is used instead, consistent with
vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
from vllm . forward_context import get_forward_context
2025-07-09 08:52:24 +08:00
from vllm . model_executor . layers . fused_moe . config import \
FusedMoEConfig # isort: skip
from vllm . model_executor . layers . fused_moe . config import \
FusedMoEParallelConfig # isort: skip
2025-04-19 17:38:18 +08:00
from vllm . model_executor . layers . fused_moe . layer import (
2025-07-03 18:36:17 +08:00
FusedMoE , UnquantizedFusedMoEMethod , determine_expert_map )
2025-05-28 21:18:41 +08:00
from vllm . model_executor . layers . quantization . base_config import \
QuantizationConfig
2025-04-19 17:38:18 +08:00
2025-05-13 12:52:30 +08:00
import vllm_ascend . envs as envs_ascend
2025-06-05 16:28:01 +08:00
from vllm_ascend . ascend_config import get_ascend_config
2025-07-28 14:06:20 +08:00
from vllm_ascend . ascend_forward_context import FusedMoEState
2025-07-10 10:57:24 +08:00
from vllm_ascend . distributed . communication_op import \
data_parallel_reduce_scatter
2025-07-28 14:06:20 +08:00
from vllm_ascend . distributed . parallel_state import get_mc2_group
2025-06-09 19:28:11 +08:00
from vllm_ascend . ops . expert_load_balancer import ExpertLoadBalancer
2025-08-14 11:50:53 +08:00
from vllm_ascend . ops . layers . experts_selector import select_experts
2025-08-02 09:49:10 +08:00
from vllm_ascend . ops . moe_dispatcher . token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher , MoEDispatcherConfig )
2025-08-07 09:15:49 +08:00
from vllm_ascend . ops . sequence_parallel import MetadataForPadding
2025-08-28 10:13:35 +08:00
from vllm_ascend . utils import ( ACL_FORMAT_FRACTAL_NZ , dispose_tensor ,
get_all_reduce_merge_state ,
2025-07-21 19:43:30 +08:00
get_rm_router_logits_state , is_310p )
2025-04-19 17:38:18 +08:00
[CI]Moe alltoall communication optimization (#1067)
[CI]Moe alltoall communication optimization
The DeepSeek V3/R1 model has 256 routing experts. During parallel
inference, if the load of an EP rank is high, the overall communication
and computing time is slowed down, which becomes a weakness of parallel
inference because the load is unevenly distributed. However, the data
volume in the prefill phase is large, and the inter-card communication
time consumption/calculation time consumption and the data volume are
closely related to each other. Therefore, less non-linear precision loss
can be used to obtain a near-linear performance improvement.
During parallel inference, global synchronization occurs during
communication. As a result, the card with low load completes the
calculation first and waits for the card with the highest load to
complete the calculation. Therefore, if the load is unbalanced, the card
with high load slows down the overall time consumption. Significant
performance gains can be achieved by discarding a small number of
tokens, which is unacceptable in some precision-sensitive scenarios.
However, similar to quantization, it is a solution that uses an
acceptable precision loss in some scenarios for performance. In
addition, a trade-off between performance and precision can be achieved
by configuring a proportion of discarded tokens.
Perform the test on A3. The batch size is 8 (B), the prompt length is
3.5K tokens (S), and the parallel configuration is as follows: AttnDP=2,
AttnTP=8, MoeTP=1, and MoeEP=16. In this scenario, we got a 10%-15%
performance gain.
In the next version, we will add an alltoallv MoE implementation.
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
2025-06-07 10:15:56 +08:00
# Snapshot of the `MOE_ALL2ALL_BUFFER` setting exposed by `vllm_ascend.envs`,
# evaluated once when this module is imported.
# NOTE(review): presumably toggles the buffered all-to-all MoE path — confirm
# against the definition in vllm_ascend/envs.py.
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
                     max_row_per_ep_rank: int, num_tokens: int,
                     top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Pack expert ids into fixed-capacity per-EP-rank slots, dropping overflow.

    EP rank ``r`` owns expert ids ``[r * E, (r + 1) * E)`` with
    ``E = expert_num // ep_size`` and receives ``max_row_per_ep_rank`` slots
    in the padded output. Entries routed to a rank beyond that capacity are
    dropped.

    Rank membership is found by scanning for runs of equal rank values, so
    ``topk_ids`` is expected to be 1-D and grouped by rank (e.g. sorted by
    expert id). NOTE(review): confirm every caller passes sorted ids; if a
    rank appears in more than one run, its runs write to overlapping slots.

    Args:
        topk_ids: 1-D tensor of routed expert ids.
        expert_num: Total number of experts; also the value used for padding.
        ep_size: Number of expert-parallel ranks.
        max_row_per_ep_rank: Number of slots available to each EP rank.
        num_tokens: Number of tokens; together with ``top_k`` it is only used
            to detect the empty case.
        top_k: Number of experts selected per token.

    Returns:
        ``(topk_ids_pad, unpad_indices)``:

        * ``topk_ids_pad`` has shape ``(ep_size * max_row_per_ep_rank,)`` and
          the dtype of ``topk_ids``. Rank ``r``'s kept ids start at slot
          ``r * max_row_per_ep_rank``, and unfilled slots hold ``expert_num``.
        * ``unpad_indices`` is an int64 tensor with one entry per element of
          ``topk_ids``. Each entry is the 0-based position of that element
          among the kept elements, or ``-1`` if the element was dropped.
          It is empty when there is nothing to route.

    Raises:
        ValueError: If ``expert_num // ep_size`` is 0. This is only checked
            when ``num_tokens * top_k`` is non-zero.
    """
    device = topk_ids.device
    original_dtype = topk_ids.dtype
    output_len = ep_size * max_row_per_ep_rank

    def _all_padding() -> tuple[torch.Tensor, torch.Tensor]:
        # Nothing routed: every output slot is padding, nothing maps back.
        topk_ids_pad = torch.full((output_len, ),
                                  expert_num,
                                  dtype=original_dtype,
                                  device=device)
        unpad_indices = torch.full((0, ),
                                   -1,
                                   dtype=torch.long,
                                   device=device)
        return topk_ids_pad, unpad_indices

    if num_tokens * top_k == 0:
        return _all_padding()

    experts_per_ep_rank_val = expert_num // ep_size
    if experts_per_ep_rank_val == 0:
        raise ValueError(
            "expert_num // ep_size is 0, which leads to division by zero in "
            "ep_rank calculation. Ensure expert_num >= ep_size.")

    num_entries = topk_ids.shape[0]
    if num_entries == 0:
        # The run detection below seeds `is_new_segment` with one leading
        # True. That mask cannot index a zero-length tensor, so an empty
        # `topk_ids` (even with num_tokens * top_k != 0) is handled here.
        return _all_padding()

    # EP rank owning each entry. Truncation equals floor for the
    # non-negative ids.
    assigned_ep_rank = (topk_ids.float() /
                        experts_per_ep_rank_val).to(original_dtype)
    indices_arange = torch.arange(num_entries, device=device)

    # Mark the first entry of each run of equal ranks (entry 0 always starts
    # a run).
    is_new_segment = torch.cat(
        (torch.tensor([True], device=device),
         assigned_ep_rank[1:] != assigned_ep_rank[:-1]))

    # Carry each run's start index forward: the markers are -1 everywhere
    # except at run starts, so the running max is the latest run start.
    temp_start_markers = torch.full_like(indices_arange,
                                         -1,
                                         dtype=indices_arange.dtype)
    temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
    start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
    token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token

    # Entries beyond the per-rank capacity are dropped.
    is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
    cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
    unpad_indices = torch.where(
        is_kept_mask, cumsum_kept - 1,
        torch.tensor(-1, device=device, dtype=torch.long))

    # Dropped entries are scattered into one extra trailing slot, which is
    # sliced off afterwards.
    # NOTE(review): an id whose rank is >= ep_size would land past that slot.
    all_destination_indices = (assigned_ep_rank * max_row_per_ep_rank +
                               token_intra_ep_rank_idx)
    temp_pad_buffer = torch.full((output_len + 1, ),
                                 expert_num,
                                 dtype=original_dtype,
                                 device=device)
    output_len_tensor = torch.tensor(output_len,
                                     dtype=torch.long,
                                     device=device)
    scatter_indices = torch.where(is_kept_mask, all_destination_indices,
                                  output_len_tensor)
    temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
    topk_ids_pad = temp_pad_buffer[:output_len]
    return topk_ids_pad, unpad_indices
2025-05-13 12:52:30 +08:00
2025-04-19 17:38:18 +08:00
2025-08-02 09:49:10 +08:00
def apply_mlp(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    group_list: torch.Tensor,
    group_list_type: int = 1,
) -> torch.Tensor:
    """Apply the grouped expert MLP: gate_up_proj -> swiglu -> down_proj.

    The rows of ``hidden_states`` are split into per-expert groups according
    to ``group_list`` (``group_type=0`` splits along the token axis). Each
    group is multiplied by its own expert's weights.

    Args:
        hidden_states: Input hidden states of shape (num_tokens, hidden_size).
        w1: Gate/up projection weights of shape
            (num_experts, intermediate_size * 2, hidden_size). They are
            transposed here to (num_experts, hidden_size,
            intermediate_size * 2) before the grouped matmul.
        w2: Down projection weights of shape
            (num_experts, hidden_size, intermediate_size). They are transposed
            here to (num_experts, intermediate_size, hidden_size).
        group_list: Per-expert token split of shape (num_experts,), encoded
            as described by ``group_list_type``.
        group_list_type: Encoding of ``group_list``, forwarded to
            ``torch_npu.npu_grouped_matmul``. 0 means cumulative token counts
            (cumsum). 1, the default, means per-expert token counts.

    Returns:
        Output hidden states of shape (num_tokens, hidden_size) after the MLP.
    """
    # Lay weights out as (num_experts, in_features, out_features) for the
    # grouped matmul.
    w1 = w1.transpose(1, 2)
    # split_item=2 makes the op return a single (un-split) output tensor
    # wrapped in a list; [0] unwraps it.
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]

    hidden_states = torch_npu.npu_swiglu(hidden_states)

    w2 = w2.transpose(1, 2)
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]

    return hidden_states
2025-06-28 16:14:49 +08:00
def fused_experts_moge (
[Platform] Add initial experimental support for Altlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P, this patch squash
below PR into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
User can run vLLM on Altlas 300I DUO series
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit test missing because need a real 310P image to have the test,
will add in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
hidden_states : torch . Tensor ,
w1 : torch . Tensor ,
w2 : torch . Tensor ,
2025-07-21 09:08:04 +08:00
moe_parallel_config : FusedMoEParallelConfig ,
[Platform] Add initial experimental support for Altlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P, this patch squash
below PR into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
User can run vLLM on Altlas 300I DUO series
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit test missing because need a real 310P image to have the test,
will add in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
topk_weights : torch . Tensor ,
topk_ids : torch . Tensor ,
top_k : int ,
global_num_experts : int ,
expert_map : torch . Tensor = None ,
apply_router_weight_on_input : bool = False ,
) - > torch . Tensor :
"""
Args :
hidden_states : Hidden states of shape ( num_tokens , hidden_size ) .
w1 : Expert weights1 of shape ( num_experts , intermediate_size * 2 , hidden_size ) .
w2 : Expert weights2 of shape ( num_experts , hidden_size , intermediate_size ) .
topk_weights : Routing weights of shape ( num_tokens , top_k ) .
topk_ids : Selected expert IDs of shape ( num_tokens , top_k ) .
top_k : Number of experts to select .
expert_map : Expert mapping of shape ( num_experts , ) .
Returns :
hidden_states : Hidden states after routing .
"""
2025-07-21 09:08:04 +08:00
ep_size = moe_parallel_config . ep_size
[Platform] Add initial experimental support for Altlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P, this patch squash
below PR into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
User can run vLLM on Altlas 300I DUO series
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit test missing because need a real 310P image to have the test,
will add in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
local_num_experts = global_num_experts / / ep_size
local_num_group = top_k / / ep_size
if apply_router_weight_on_input :
assert ( topk_weights . dim ( ) == 2
) , " `topk_weights` should be in shape (num_tokens, topk) "
_ , topk = topk_weights . shape
assert (
topk == 1
) , " Only support topk=1 when `apply_router_weight_on_input` is True "
hidden_states = hidden_states * topk_weights . to ( hidden_states . dtype )
bsz , _ = hidden_states . shape
flatten_topk_ids = topk_ids . view ( - 1 )
sorted_topk_ids = torch . argsort ( flatten_topk_ids . float ( ) )
sorted_topk_ids = sorted_topk_ids . to ( torch . int32 )
sorted_hidden_states = hidden_states . index_select (
0 , sorted_topk_ids / / local_num_group )
experts_id = torch . arange ( 0 ,
local_num_experts ,
dtype = topk_ids . dtype ,
device = topk_ids . device )
num_tokens_per_expert = ( flatten_topk_ids . unsqueeze ( - 1 ) == experts_id ) . to (
torch . float32 ) . sum ( 0 )
topk_scales = topk_weights . view ( - 1 ) . index_select (
0 , sorted_topk_ids ) . unsqueeze ( - 1 )
group_list = num_tokens_per_expert . cumsum ( dim = 0 ) . to ( torch . int64 )
w1 = w1 . transpose ( 1 , 2 )
gate_up_out = torch_npu . npu_grouped_matmul (
x = [ sorted_hidden_states ] ,
weight = [ w1 ] ,
split_item = 2 ,
group_list_type = 0 ,
group_type = 0 ,
group_list = group_list ,
) [ 0 ]
2025-06-28 16:14:49 +08:00
if is_310p ( ) :
gate_up_out = torch_npu . npu_swiglu ( gate_up_out . to ( torch . float32 ) ) . to (
torch . float16 )
else :
gate_up_out = torch_npu . npu_swiglu ( gate_up_out )
[Platform] Add initial experimental support for Altlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P, this patch squash
below PR into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
User can run vLLM on Altlas 300I DUO series
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit test missing because need a real 310P image to have the test,
will add in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
gate_up_out * = topk_scales
w2 = w2 . transpose ( 1 , 2 )
down_out_list = torch_npu . npu_grouped_matmul (
x = [ gate_up_out ] ,
weight = [ w2 ] ,
split_item = 2 ,
group_list_type = 0 ,
group_type = 0 ,
group_list = group_list ,
) [ 0 ]
2025-06-28 16:14:49 +08:00
unsorted_topk_ids = torch . argsort ( sorted_topk_ids . float ( ) ) . to ( torch . int32 )
[Platform] Add initial experimental support for Altlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P, this patch squash
below PR into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
Users can run vLLM on the Atlas 300I DUO series.
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit tests are missing because they need a real 310P image; they will be
added in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
unsorted_hidden_states = down_out_list . index_select ( 0 , unsorted_topk_ids )
final_hidden_states = unsorted_hidden_states . reshape (
bsz , top_k / / ep_size , - 1 ) . sum ( 1 )
return final_hidden_states
2025-08-28 10:13:35 +08:00
def quant_apply_mlp(hidden_states: torch.Tensor,
                    w1: torch.Tensor,
                    w1_scale: torch.Tensor,
                    w2: torch.Tensor,
                    w2_scale: torch.Tensor,
                    group_list: torch.Tensor,
                    dynamic_scale: Optional[torch.Tensor] = None,
                    group_list_type: int = 1,
                    w1_scale_bias: Optional[torch.Tensor] = None,
                    w2_scale_bias: Optional[torch.Tensor] = None
                    ) -> torch.Tensor:
    """Run the quantized expert MLP: gate_up_proj -> swiglu -> down_proj.

    Two execution paths exist:

    * Fused-MoE state is MC2 and no ``w1_scale_bias`` is given: the first
      grouped matmul emits int32 accumulators, and
      ``npu_dequant_swiglu_quant`` then dequantizes them, applies swiglu and
      re-quantizes in a single op before the second grouped matmul.
    * Otherwise: each grouped matmul is given ``scale`` / ``per_token_scale``
      directly, followed by ``npu_swiglu`` and a fresh
      ``npu_dynamic_quant``.
      - When ``w1_scale_bias`` is provided (the "w4a8 scene" per the TODO
        below), the per-matmul biases are passed.
      - A cumulative (type 0) ``group_list`` is converted to per-group
        counts (type 1).
      - The output dtype is forced to bfloat16.

    Args:
        hidden_states: Expert-grouped input tokens. If ``dynamic_scale`` is
            None they are quantized here with ``npu_dynamic_quant`` and the
            original tensor is disposed to free NPU memory.
        w1: gate_up_proj weight.
        w1_scale: Dequantization scale for ``w1``.
        w2: down_proj weight.
        w2_scale: Dequantization scale for ``w2``. Its dtype is the default
            output dtype.
        group_list: Per-expert grouping info for ``npu_grouped_matmul``,
            interpreted according to ``group_list_type``.
        dynamic_scale: Per-token scale of an already-quantized
            ``hidden_states``. None means "quantize here".
        group_list_type: Forwarded to ``npu_grouped_matmul``. The code treats
            0 as cumulative offsets and 1 as per-group counts (see the
            ``torch.diff`` conversion below).
        w1_scale_bias: Optional bias for the first grouped matmul. Its
            presence selects the non-MC2 bias path.
        w2_scale_bias: Optional bias for the second grouped matmul (used
            together with ``w1_scale_bias``).

    Returns:
        The down-projection output.
    """
    if dynamic_scale is None:
        unquantized_hidden_states = hidden_states
        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
            hidden_states)
        # Dispose the original unquantized hidden states to save NPU memory,
        # because they are no longer used.
        dispose_tensor(unquantized_hidden_states)
    else:
        pertoken_scale = dynamic_scale

    bias1, bias2 = None, None
    _output_dtype = w2_scale.dtype

    is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
    if w1_scale_bias is None and is_mc2:
        w1_scale = w1_scale.to(torch.float32)
        # gmm1: gate_up_proj, producing raw int32 accumulators.
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w1],
            split_item=3,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=torch.int32)[0]
        # act_fn: swiglu, fused with dequantization of gmm1's output and
        # quantization of the input for gmm2.
        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
            x=hidden_states,
            weight_scale=w1_scale,
            activation_scale=pertoken_scale,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=group_list,
            activate_left=True,
            quant_mode=1,
        )
        # gmm2: down_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w2],
            scale=[w2_scale],
            per_token_scale=[swiglu_out_scale],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=_output_dtype)[0]
    else:
        if w1_scale_bias is not None:
            if group_list_type == 0:
                # The bias path needs per-group counts: turn cumulative
                # offsets into counts (first element, then differences).
                group_list = torch.cat(
                    [group_list[:1],
                     torch.diff(group_list, dim=0)])
                group_list_type = 1
            bias1 = [w1_scale_bias]
            bias2 = [w2_scale_bias]
            # TODO w4a8 scene: dynamic acquisition of dtype in the future
            _output_dtype = torch.bfloat16

        # gmm1: gate_up_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w1],
            scale=[w1_scale],
            bias=bias1,
            per_token_scale=[pertoken_scale],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=_output_dtype)[0]
        # act_fn: swiglu, then re-quantize the activation for gmm2.
        hidden_states = torch_npu.npu_swiglu(hidden_states)
        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
            hidden_states)
        # gmm2: down_proj
        hidden_states = torch_npu.npu_grouped_matmul(
            x=[hidden_states],
            weight=[w2],
            scale=[w2_scale],
            bias=bias2,
            per_token_scale=[swiglu_out_scale],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
            output_dtype=_output_dtype)[0]
    return hidden_states
2025-03-11 21:08:02 +08:00
2025-02-21 17:07:37 +08:00
2025-08-28 10:13:35 +08:00
def unquant_apply_mlp(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        group_list: torch.Tensor,
        group_list_type: int = 1,
        topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the unquantized expert MLP: gate_up_proj -> swiglu -> down_proj.

    Both projections are grouped matmuls over the expert-grouped tokens in
    ``hidden_states``. Each uses the weight transposed on its last two dims
    and shares ``group_list`` / ``group_list_type``.

    Args:
        hidden_states: Expert-grouped input tokens.
        w1: gate_up_proj weight; ``w1.transpose(1, 2)`` is fed to the matmul.
        w2: down_proj weight; ``w2.transpose(1, 2)`` is fed to the matmul.
        group_list: Per-expert grouping info for ``npu_grouped_matmul``.
        group_list_type: Forwarded to ``npu_grouped_matmul``.
        topk_scales: If given, multiplied in place into the activated
            gate_up output before the down projection.

    Returns:
        The down-projection output.
    """

    def _expert_matmul(tokens: torch.Tensor,
                       weight: torch.Tensor) -> torch.Tensor:
        # One grouped matmul over all experts with the transposed weight.
        return torch_npu.npu_grouped_matmul(
            x=[tokens],
            weight=[weight.transpose(1, 2)],
            split_item=2,
            group_list_type=group_list_type,
            group_type=0,
            group_list=group_list,
        )[0]

    projected = _expert_matmul(hidden_states, w1)

    if is_310p():
        # On 310P, swiglu is evaluated in float32 and the result is cast
        # back to float16.
        activated = torch_npu.npu_swiglu(projected.to(torch.float32)).to(
            torch.float16)
    else:
        activated = torch_npu.npu_swiglu(projected)

    if topk_scales is not None:
        activated *= topk_scales

    return _expert_matmul(activated, w2)
2025-03-11 21:08:02 +08:00
2025-08-28 10:13:35 +08:00
def unified_apply_mlp(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w1_scale: torch.Tensor,
        w2: torch.Tensor,
        w2_scale: torch.Tensor,
        group_list: torch.Tensor,
        dynamic_scale: Optional[torch.Tensor] = None,
        group_list_type: int = 1,
        w1_scale_bias: Optional[torch.Tensor] = None,
        w2_scale_bias: Optional[torch.Tensor] = None,
        topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the expert MLP, choosing the quantized or unquantized kernel.

    The choice follows ``get_forward_context().with_quant``:

    * Quantized path (``quant_apply_mlp``): uses the scale, bias and
      ``dynamic_scale`` arguments. ``topk_scales`` is NOT forwarded.
    * Unquantized path (``unquant_apply_mlp``): uses only the weights,
      ``group_list``, ``group_list_type`` and ``topk_scales``. The scale and
      bias arguments are ignored.

    Returns:
        The expert MLP output from the selected implementation.
    """
    if get_forward_context().with_quant:
        return quant_apply_mlp(hidden_states=hidden_states,
                               w1=w1,
                               w1_scale=w1_scale,
                               w2=w2,
                               w2_scale=w2_scale,
                               group_list=group_list,
                               dynamic_scale=dynamic_scale,
                               group_list_type=group_list_type,
                               w1_scale_bias=w1_scale_bias,
                               w2_scale_bias=w2_scale_bias)
    return unquant_apply_mlp(hidden_states=hidden_states,
                             w1=w1,
                             w2=w2,
                             group_list=group_list,
                             group_list_type=group_list_type,
                             topk_scales=topk_scales)
def unified_fused_experts_eager(hidden_states: torch.Tensor,
                                w1: torch.Tensor,
                                w2: torch.Tensor,
                                topk_weights: torch.Tensor,
                                topk_ids: torch.Tensor,
                                row_idx: torch.Tensor,
                                expert_map: Optional[torch.Tensor] = None,
                                log2phy: Optional[torch.Tensor] = None,
                                global_redundant_expert_num: int = 0,
                                w1_scale: Optional[torch.Tensor] = None,
                                w1_scale_bias: Optional[torch.Tensor] = None,
                                w2_scale: Optional[torch.Tensor] = None,
                                w2_scale_bias: Optional[torch.Tensor] = None,
                                shared_experts: Optional[torch.Tensor] = None,
                                shared_gate_up: Optional[Any] = None,
                                shared_dequant_scale: Optional[Any] = None,
                                mc2_mask: Optional[torch.Tensor] = None,
                                apply_router_weight_on_input: bool = False):
    """Run the routed experts eagerly: dispatch -> expert MLP -> combine.

    The token dispatcher comes from the current forward context.
    ``token_dispatch`` returns a mapping. Its ``"hidden_states"`` and
    ``"group_list"`` entries are required. ``"dynamic_scale"``,
    ``"group_list_type"`` and ``"topk_scales"`` are read with ``.get`` and
    may be absent; an absent value is passed on as None. The expert output
    is handed back to the same dispatcher's ``token_combine``.

    Returns:
        Whatever ``token_combine`` returns for the expert output.
    """
    dispatcher = get_forward_context().token_dispatcher

    dispatched = dispatcher.token_dispatch(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        row_idx=row_idx,
        expert_map=expert_map,
        log2phy=log2phy,
        global_redundant_expert_num=global_redundant_expert_num,
        shared_experts=shared_experts,
        shared_gate_up=shared_gate_up,
        shared_dequant_scale=shared_dequant_scale,
        mc2_mask=mc2_mask,
        apply_router_weight_on_input=apply_router_weight_on_input)

    mlp_output = unified_apply_mlp(
        hidden_states=dispatched["hidden_states"],
        w1=w1,
        w1_scale=w1_scale,
        w2=w2,
        w2_scale=w2_scale,
        group_list=dispatched["group_list"],
        dynamic_scale=dispatched.get("dynamic_scale"),
        group_list_type=dispatched.get("group_list_type"),
        w1_scale_bias=w1_scale_bias,
        w2_scale_bias=w2_scale_bias,
        topk_scales=dispatched.get("topk_scales"))

    return dispatcher.token_combine(mlp_output)
2025-04-19 17:38:18 +08:00
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
    """Ascend NPU implementation of the unquantized fused-MoE method.

    After weight loading, the expert weights are optionally padded and, on
    everything except 310P, cast to the ``ACL_FORMAT_FRACTAL_NZ`` layout.
    ``apply`` selects experts for each token and runs them through
    ``unified_fused_experts_eager``.
    """

    def __init__(self, moe: FusedMoEConfig = None):
        super().__init__(moe=moe)

        cfg = get_current_vllm_config()
        self.global_batch_size = cfg.scheduler_config.max_num_seqs
        self.max_model_len = cfg.model_config.max_model_len
        # The return value is unused here.
        # NOTE(review): presumably called for its side effect (ensuring the
        # Ascend config is initialised) — confirm.
        get_ascend_config()

        try:
            mc2_device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            rank_in_group = torch.distributed.get_rank(group=mc2_device_group)
            hccl_backend = mc2_device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = hccl_backend.get_hccl_comm_name(
                rank_in_group)
        except AttributeError:
            # No usable MC2 group/backend: leave the comm name unset.
            self.moe_all_to_all_group_name = None

    def process_weights_after_loading(self, layer):
        """Pad expert weights, freeze them and (not on 310P) cast to NZ."""
        # Deliberately skips UnquantizedFusedMoEMethod's own implementation
        # and dispatches to the next class in the MRO.
        super(UnquantizedFusedMoEMethod,
              self).process_weights_after_loading(layer)

        weight_names = ("w13_weight", "w2_weight")
        for name in weight_names:
            padded = self._maybe_pad_weight(getattr(layer, name).data)
            setattr(layer, name,
                    torch.nn.Parameter(padded, requires_grad=False))

        if not is_310p():
            for name in weight_names:
                param = getattr(layer, name)
                param.data = torch_npu.npu_format_cast(param.data,
                                                       ACL_FORMAT_FRACTAL_NZ)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = False,
        enable_force_load_balance: bool = False,
        shared_experts: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Route ``x`` to its top-k experts and return the fused MoE output.

        ``is_prefill`` is accepted for interface compatibility but unused
        here. Of ``kwargs``, only ``"mc2_mask"`` is read (default None).
        """
        weights, expert_ids, row_idx = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
            is_unquantized=True)

        weights = weights.to(x.dtype)

        # Naive expert load balancing: replace the routing with uniformly
        # random expert ids so that no single rank accumulates too many
        # tokens. It is currently only activated during profile runs.
        # ``self.use_aclgraph`` is only read when forced balancing is on.
        # NOTE(review): it is not assigned in this class; presumably the
        # base __init__ sets it — confirm.
        if enable_force_load_balance and not self.use_aclgraph:
            expert_ids = torch.randint_like(expert_ids, 0, global_num_experts)

        return unified_fused_experts_eager(hidden_states=x,
                                           w1=layer.w13_weight,
                                           w2=layer.w2_weight,
                                           topk_weights=weights,
                                           topk_ids=expert_ids,
                                           row_idx=row_idx,
                                           expert_map=expert_map,
                                           shared_experts=shared_experts,
                                           mc2_mask=kwargs.get(
                                               "mc2_mask", None))
2025-04-19 17:38:18 +08:00
class AscendFusedMoE ( FusedMoE ) :
2025-06-09 19:28:11 +08:00
# The moe_counter parameter is required during the initialization of EPLB
# to identify the current layer index within the MOE model.
moe_counter = - 1
2025-05-16 12:14:55 +08:00
def __init__(
    self,
    num_experts: int,  # Global number of experts
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: Optional[torch.dtype] = None,
    reduce_results: bool = False,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    tp_size: Optional[int] = None,
    ep_size: Optional[int] = None,
    dp_size: Optional[int] = None,
    prefix: str = "",
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
):
    """Construct an Ascend fused-MoE layer.

    The base ``FusedMoE.__init__`` runs first. Afterwards this constructor:

    * recomputes the MoE parallel config and routing attributes;
    * builds the expert map, either from an EPLB expert-map file
      (``ascend_config.expert_map_path``) or via ``determine_expert_map``;
    * selects the quantization method and creates its weights;
    * optionally creates all-to-all-seq token dispatchers; and
    * calls ``setup_token_dispatchers``.

    Args:
        num_experts: Global number of experts.
        top_k: Number of experts selected per token.
        hidden_size: Hidden dimension of the layer input.
        intermediate_size: Unsharded expert intermediate size. It must be
            divisible by the tensor-parallel size.
        params_dtype: Parameter dtype. Defaults to
            ``torch.get_default_dtype()``.
        reduce_results: Stored on the layer. It is forced to ``False`` when
            the all-to-all-seq dispatcher is enabled.
        renormalize: Stored and forwarded to the quant method at apply time.
        use_grouped_topk: If True, ``num_expert_group`` and ``topk_group``
            must both be provided.
        num_expert_group: Number of expert groups for grouped top-k.
        topk_group: Number of groups selected in grouped top-k.
        quant_config: Quantization config. ``None`` selects
            ``AscendUnquantizedFusedMoEMethod``.
        tp_size: Tensor-parallel size override. Defaults to the TP world size.
        ep_size: Forwarded to the base class only.
        dp_size: Data-parallel size override. Defaults to the DP world size.
        prefix: Layer name prefix, used to look up the quant method.
        custom_routing_function: Optional custom routing callable.
        scoring_func: Router scoring function. Only "softmax" is allowed
            unless grouped top-k is used.
        e_score_correction_bias: Optional routing correction bias.
        activation: Expert activation name.
        apply_router_weight_on_input: Forwarded to the base class.

    Raises:
        ValueError: If ``scoring_func`` is not "softmax" and grouped top-k
            is disabled.
    """
    # NOTE: the base-class __init__ is invoked, but several attributes it
    # sets (parallel config, routing options, expert map, quant method) are
    # recomputed below. TODO: make __init__() of AscendFusedMoE clearer.
    super().__init__(
        num_experts=num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        params_dtype=params_dtype,
        reduce_results=reduce_results,
        renormalize=renormalize,
        use_grouped_topk=use_grouped_topk,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        quant_config=quant_config,
        tp_size=tp_size,
        ep_size=ep_size,
        dp_size=dp_size,
        prefix=prefix,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
    # Class-wide counter: each new instance takes the next index. EPLB uses
    # it to identify this layer's position within the MoE model.
    AscendFusedMoE.moe_counter += 1
    self.moe_instance_id = AscendFusedMoE.moe_counter

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()

    vllm_config = get_current_vllm_config()
    self.moe_parallel_config = FusedMoEParallelConfig.make(
        tp_size_=(tp_size if tp_size is not None else
                  get_tensor_model_parallel_world_size()),
        dp_size_=(dp_size
                  if dp_size is not None else get_dp_group().world_size),
        vllm_parallel_config=vllm_config.parallel_config)

    self.top_k = top_k
    self.num_experts = num_experts
    self.global_num_experts = num_experts
    assert intermediate_size % self.tp_size == 0
    self.intermediate_size_per_partition = intermediate_size // self.tp_size
    self.reduce_results = reduce_results
    self.renormalize = renormalize
    self.use_grouped_topk = use_grouped_topk
    if self.use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
    self.num_expert_group = num_expert_group
    self.topk_group = topk_group
    self.custom_routing_function = custom_routing_function
    self.scoring_func = scoring_func
    self.e_score_correction_bias = e_score_correction_bias
    self.expert_map = None
    self.activation = activation
    self.log2phy = None
    self.global_redundant_expert_num = 0

    # Heuristic: a model with exactly 256 global experts is treated as
    # DeepSeek V3/R1 when deciding the router-logits / all-reduce strategy.
    is_deepseek_v3_r1 = self.global_num_experts == 256
    self.rm_router_logits = get_rm_router_logits_state(
        self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
    self.all_reduce_merge = get_all_reduce_merge_state(
        self.moe_parallel_config.ep_size, is_deepseek_v3_r1)

    ascend_config = get_ascend_config()
    expert_map_path = ascend_config.expert_map_path
    if expert_map_path and os.path.exists(expert_map_path):
        # MoE expert load balancing: placement comes from the map file.
        expert_load_balancer = ExpertLoadBalancer(expert_map_path,
                                                  self.global_num_experts)
        self.local_num_experts, self.expert_map = \
            expert_load_balancer.get_rank_placement_map(
                self.moe_instance_id,
                get_ep_group().rank_in_group)
        self.log2phy = expert_load_balancer.get_rank_log2phy_map(
            self.moe_instance_id,
            get_ep_group().rank_in_group)
        self.global_redundant_expert_num = \
            expert_load_balancer.get_global_redundant_expert_num()
    else:
        # Default placement derived from the EP size and this rank's index
        # in the EP group (see determine_expert_map for the map format).
        self.local_num_experts, self.expert_map = determine_expert_map(
            self.ep_size,
            get_ep_group().rank_in_group, self.global_num_experts)

    self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

    if self.scoring_func != "softmax" and not self.use_grouped_topk:
        raise ValueError("Only softmax scoring function is supported for "
                         "non-grouped topk.")
    moe = FusedMoEConfig.make(
        num_experts=self.global_num_experts,
        experts_per_token=top_k,
        hidden_dim=hidden_size,
        num_local_experts=self.local_num_experts,
        moe_parallel_config=self.moe_parallel_config,
        # TODO (bnell): this needs to be fixed for quantized types.
        in_dtype=params_dtype,
        quant_config=quant_config)

    self.moe_config = moe

    if quant_config is None:
        self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
    else:
        self.quant_method = quant_config.get_quant_method(self, prefix)

    assert self.quant_method is not None
    # Count of experts mapped to this rank (entries != -1). Note this is a
    # tensor when expert_map is set, an int otherwise.
    local_num_experts = torch.sum(self.expert_map != -1) \
        if self.expert_map is not None else num_experts
    moe_quant_params = {
        "num_experts": local_num_experts,
        "hidden_size": hidden_size,
        "intermediate_size_per_partition":
        self.intermediate_size_per_partition,
        "params_dtype": params_dtype,
        "weight_loader": self.weight_loader,
    }
    # need full intermediate size pre-sharding for WNA16 act order
    if (self.quant_method.__class__.__name__
            in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
        moe_quant_params["intermediate_size_full"] = intermediate_size

    self.ep_group = get_ep_group()
    # NOTE: self.tp_group is not expert_tp_group
    self.tp_group = get_tp_group().device_group
    self.quant_method.create_weights(layer=self, **moe_quant_params)

    self.token_dispatcher = None
    if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
            self.quant_method, AscendUnquantizedFusedMoEMethod):
        self.reduce_results = False
        moe_dispatcher_config = (
            MoEDispatcherConfig()
            .set_num_moe_experts(self.global_num_experts)
            .set_num_local_experts(self.local_num_experts)
            .set_moe_router_topk(top_k)
            .set_group_topk(topk_group)
            .set_num_groups(num_expert_group)
            .set_expert_bias(e_score_correction_bias)
            .set_scaling_factor(1.0)
            .build())
        self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(
            moe_dispatcher_config)
        if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
            # Dual-batch overlap needs a second, independent dispatcher.
            token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(
                moe_dispatcher_config)
            self.token_dispatchers = [
                self.token_dispatcher, token_dispatcher1
            ]

    # Use a distinct name so the `ep_size` constructor argument is not
    # shadowed: the dispatcher EP size always comes from the EP group
    # (or 1 when expert parallelism is disabled), not from the caller.
    dispatcher_ep_size = (get_ep_group().world_size
                          if vllm_config.parallel_config.enable_expert_parallel
                          else 1)
    with_quant = quant_config is not None
    from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
        setup_token_dispatchers
    setup_token_dispatchers(
        dispatcher_ep_size,
        top_k=self.top_k,
        num_experts=self.global_num_experts,
        num_global_redundant_experts=self.global_redundant_expert_num,
        num_local_experts=self.local_num_experts,
        with_quant=with_quant)
2025-07-07 22:36:03 +08:00
def naive_multicast(self, x: torch.Tensor,
                    cu_tokens_across_dp_cpu: torch.Tensor):
    """Assemble every DP rank's rows into one buffer using broadcasts.

    ``cu_tokens_across_dp_cpu[i]`` is used as the exclusive end row of DP
    rank ``i`` in the concatenated buffer; rank ``i`` starts where rank
    ``i - 1`` ends (rank 0 starts at row 0).

    This rank first writes its own 2-D ``x`` into its slice. Then, for each
    DP rank, that rank's slice is broadcast from it over the DP group.
    """
    assert x.dim() == 2
    offsets = cu_tokens_across_dp_cpu

    def _row_range(rank):
        # [first, last) rows owned by `rank` in the gathered buffer.
        first = 0 if rank == 0 else offsets[rank - 1]
        return first, offsets[rank]

    gathered = torch.empty((offsets[-1], x.size(1)),
                           device=x.device,
                           dtype=x.dtype)
    own_first, own_last = _row_range(self.dp_rank)
    gathered[own_first:own_last, :].copy_(x)
    for src_rank in range(self.dp_size):
        first, last = _row_range(src_rank)
        get_dp_group().broadcast(gathered[first:last, :], src_rank)
    return gathered
2025-04-19 17:38:18 +08:00
def forward ( self ,
hidden_states : torch . Tensor ,
router_logits : torch . Tensor ,
is_prefill : bool ,
2025-05-15 09:19:55 +08:00
enable_force_load_balance : bool = False ,
Support multistream of shared experts in FusedMoE (#997)
Contains on #1111 for completeness.
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
Implement multi-stream parallelism for MoE layers with shared experts,
where computation of shared experts will be overlapped with expert token
dispatch and combine. Also, when multi-stream is enabled, weights of
shared experts will be forced to replicate across all cards, regardless
of any tensor parallelism configuration, to avoid AllReduce operations.
The expected overlap is:
```
| shared gate_up | shared act | | shared down |
| dispatch | routed gate_up, act, down | combine |
```
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
No.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
Tested on 1x16 910 node, with tailored 2 layer DSKv2.
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
---------
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
2025-06-11 09:18:38 +08:00
top_k : Optional [ int ] = None ,
2025-06-25 19:56:49 +08:00
shared_experts : Optional [ Any ] = None ,
2025-07-11 08:53:17 +08:00
gate = None ,
2025-08-07 09:15:49 +08:00
replace_allreduce : bool = False ,
_metadata_for_padding : Optional [ MetadataForPadding ] = None ) :
2025-07-11 08:53:17 +08:00
2025-04-19 17:38:18 +08:00
assert self . quant_method is not None
if top_k :
real_top_k = top_k
else :
real_top_k = self . top_k
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
num_tokens , hidden_size = hidden_states . shape
2025-07-28 14:06:20 +08:00
forward_context = get_forward_context ( )
fused_moe_state = forward_context . fused_moe_state
mc2_mask = forward_context . mc2_mask
2025-07-29 23:53:19 +08:00
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share , dynamic_scale_for_share = None , None
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
if shared_experts :
2025-08-26 14:12:43 +08:00
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts ( hidden_states )
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
2025-08-07 09:15:49 +08:00
mc2_mask = forward_context . mc2_mask
enable_sp = _metadata_for_padding is not None and _metadata_for_padding . not_dummy_and_is_prefill
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
tp_size = get_tensor_model_parallel_world_size ( )
2025-08-07 09:15:49 +08:00
if enable_sp :
tp_rank = get_tensor_model_parallel_rank ( )
mc2_mask_sp = _metadata_for_padding . mc2_mask if _metadata_for_padding is not None else forward_context . mc2_mask
chunk_mc2_mask = torch . tensor_split ( mc2_mask_sp , tp_size , dim = 0 )
mc2_mask = chunk_mc2_mask [ tp_rank ]
replace_allreduce = True
2025-07-28 14:06:20 +08:00
if ( fused_moe_state not in [
2025-07-07 22:36:03 +08:00
FusedMoEState . AllGather , FusedMoEState . AllGatherEP ,
FusedMoEState . NaiveMulticast
] and not replace_allreduce ) :
2025-07-28 14:06:20 +08:00
if fused_moe_state in { FusedMoEState . MC2 } :
padding_size = forward_context . padded_num_tokens
else :
# TODO: Determine if we can remove the padding
padding_size = tp_size
2025-08-12 14:12:12 +08:00
if num_tokens < padding_size and not self . enable_shared_expert_dp :
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
hidden_states = nn . functional . pad (
2025-07-28 14:06:20 +08:00
hidden_states , ( 0 , 0 , 0 , padding_size - num_tokens ) )
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
router_logits = nn . functional . pad (
2025-07-28 14:06:20 +08:00
router_logits , ( 0 , 0 , 0 , padding_size - num_tokens ) )
if tp_size > 1 :
tp_rank = get_tensor_model_parallel_rank ( )
2025-08-12 14:12:12 +08:00
if not self . enable_shared_expert_dp :
chunk_hidden_states = torch . tensor_split ( hidden_states ,
tp_size ,
dim = 0 )
chunk_router_logits = torch . tensor_split ( router_logits ,
tp_size ,
dim = 0 )
hidden_states = chunk_hidden_states [ tp_rank ]
router_logits = chunk_router_logits [ tp_rank ]
chunk_mc2_mask = torch . tensor_split ( mc2_mask , tp_size , dim = 0 )
2025-07-28 14:06:20 +08:00
mc2_mask = chunk_mc2_mask [ tp_rank ]
2025-07-11 08:53:17 +08:00
2025-07-07 22:36:03 +08:00
if self . dp_size > 1 :
if fused_moe_state == FusedMoEState . AllGather :
# NOTE: When in torchair graph, it has been padded in model_runner_v1
2025-08-26 14:12:43 +08:00
max_tokens_across_dp = forward_context . max_tokens_across_dp
if num_tokens < max_tokens_across_dp :
hidden_states = nn . functional . pad (
hidden_states ,
( 0 , 0 , 0 , max_tokens_across_dp - num_tokens ) )
if not self . rm_router_logits :
router_logits = nn . functional . pad (
router_logits ,
2025-07-28 14:06:20 +08:00
( 0 , 0 , 0 , max_tokens_across_dp - num_tokens ) )
2025-07-07 22:36:03 +08:00
hidden_states = get_dp_group ( ) . all_gather ( hidden_states , 0 )
2025-07-11 08:53:17 +08:00
if self . rm_router_logits :
router_logits , _ = gate ( hidden_states )
else :
router_logits = get_dp_group ( ) . all_gather ( router_logits , 0 )
2025-07-07 22:36:03 +08:00
elif fused_moe_state == FusedMoEState . NaiveMulticast :
cu_tokens_across_dp_cpu = get_forward_context (
) . dp_metadata . cu_tokens_across_dp_cpu
hidden_states = self . naive_multicast ( hidden_states ,
cu_tokens_across_dp_cpu )
2025-07-11 08:53:17 +08:00
if self . rm_router_logits :
router_logits , _ = gate ( hidden_states )
else :
router_logits = self . naive_multicast (
router_logits , cu_tokens_across_dp_cpu )
2025-04-19 17:38:18 +08:00
# Matrix multiply.
Support multistream of shared experts in FusedMoE (#997)
Contains on #1111 for completeness.
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
Implement multi-stream parallelism for MoE layers with shared experts,
where computation of shared experts will be overlapped with expert token
dispatch and combine. Also, when multi-stream is enabled, weights of
shared experts will be forced to replicate across all cards, regardless
of any tensor parallelism configuration, to avoid AllReduce operations.
The expected overlap is:
```
| shared gate_up | shared act | | shared down |
| dispatch | routed gate_up, act, down | combine |
```
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
No.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
Tested on 1x16 910 node, with tailored 2 layer DSKv2.
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
---------
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
2025-06-11 09:18:38 +08:00
e_hidden_states = self . quant_method . apply (
2025-04-19 17:38:18 +08:00
layer = self ,
x = hidden_states ,
router_logits = router_logits ,
top_k = real_top_k ,
renormalize = self . renormalize ,
use_grouped_topk = self . use_grouped_topk ,
2025-04-23 16:23:25 +08:00
global_num_experts = self . global_num_experts ,
2025-04-19 17:38:18 +08:00
expert_map = self . expert_map ,
topk_group = self . topk_group ,
num_expert_group = self . num_expert_group ,
custom_routing_function = self . custom_routing_function ,
scoring_func = self . scoring_func ,
e_score_correction_bias = self . e_score_correction_bias ,
2025-05-15 09:19:55 +08:00
is_prefill = is_prefill ,
2025-06-05 23:39:38 +08:00
enable_force_load_balance = enable_force_load_balance ,
2025-06-09 19:28:11 +08:00
log2phy = self . log2phy ,
global_redundant_expert_num = self . global_redundant_expert_num ,
2025-08-26 14:12:43 +08:00
shared_experts = None ,
2025-07-28 14:06:20 +08:00
mc2_mask = mc2_mask ,
2025-08-02 09:49:10 +08:00
token_dispatcher = self . token_dispatcher ,
2025-07-29 23:53:19 +08:00
quantized_x_for_share = quantized_x_for_share ,
dynamic_scale_for_share = dynamic_scale_for_share ,
Support multistream of shared experts in FusedMoE (#997)
Contains on #1111 for completeness.
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
Implement multi-stream parallelism for MoE layers with shared experts,
where computation of shared experts will be overlapped with expert token
dispatch and combine. Also, when multi-stream is enabled, weights of
shared experts will be forced to replicate across all cards, regardless
of any tensor parallelism configuration, to avoid AllReduce operations.
The expected overlap is:
```
| shared gate_up | shared act | | shared down |
| dispatch | routed gate_up, act, down | combine |
```
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
No.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
Tested on 1x16 910 node, with tailored 2 layer DSKv2.
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
---------
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
2025-06-11 09:18:38 +08:00
)
2025-06-05 23:39:38 +08:00
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR removes the env `VLLM_ENABLE_MC2`: it is unnecessary, because
the decision can be made from the current scenario without it, and keeping
it only adds complexity.
2. This PR removes the env `USING_LCCL_COM`, because this env is already
obsolete.
3. `additional_config.expert_tensor_parallel_size` is now obsolete; the
parameter `enable_expert_parallel` is used instead, consistent with
vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
if shared_experts :
if isinstance ( e_hidden_states , tuple ) :
e_hidden_states , shared_hidden_states = e_hidden_states
2025-07-28 14:06:20 +08:00
if ( fused_moe_state not in [
2025-07-07 22:36:03 +08:00
FusedMoEState . AllGather , FusedMoEState . AllGatherEP ,
FusedMoEState . NaiveMulticast
2025-08-12 14:12:12 +08:00
] and not replace_allreduce and not self . enable_shared_expert_dp ) :
2025-07-28 14:06:20 +08:00
if tp_size > 1 :
dist . all_gather ( list ( chunk_hidden_states ) , e_hidden_states ,
self . tp_group )
final_hidden_states = torch . cat ( chunk_hidden_states , dim = 0 )
2025-08-02 09:49:10 +08:00
dispose_tensor ( e_hidden_states )
2025-07-28 14:06:20 +08:00
else :
final_hidden_states = e_hidden_states
if num_tokens < padding_size :
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR is used for resolved [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
final_hidden_states = final_hidden_states [ : num_tokens ]
2025-08-12 14:12:12 +08:00
elif self . dp_size > 1 and not self . enable_shared_expert_dp :
2025-07-07 22:36:03 +08:00
if fused_moe_state == FusedMoEState . NaiveMulticast :
start = 0 if self . dp_rank == 0 else cu_tokens_across_dp_cpu [
self . dp_rank - 1 ]
end = cu_tokens_across_dp_cpu [ self . dp_rank ]
final_hidden_states = get_dp_group ( ) . all_reduce (
e_hidden_states )
final_hidden_states = final_hidden_states [ start : end , : ]
dispose_tensor ( e_hidden_states )
elif fused_moe_state == FusedMoEState . AllGather :
2025-07-10 10:57:24 +08:00
final_hidden_states = data_parallel_reduce_scatter (
e_hidden_states , dim = 0 )
2025-07-07 22:36:03 +08:00
final_hidden_states = final_hidden_states [ : num_tokens ]
dispose_tensor ( e_hidden_states )
2025-07-31 15:30:28 +08:00
else :
final_hidden_states = e_hidden_states
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR is used for resolved [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
else :
final_hidden_states = e_hidden_states
2025-07-10 12:07:05 +08:00
if tp_size > 1 and not self . all_reduce_merge and fused_moe_state in [
2025-07-07 22:36:03 +08:00
FusedMoEState . AllGather , FusedMoEState . AllGatherEP ,
FusedMoEState . NaiveMulticast
] :
[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
This PR is used for resolved [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147)
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this
env is useless, we can make judgments based on the current scenario
without this env, it will only increase complexity.
2. This PR has removed the env `USING_LCCL_COM`, because this env has
already expired.
3. `additional_config.expert_tensor_parallel_size` has already expired,
and now we also use parameter `enable_expert_parallel`, consistent with
the vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: zzzzwwjj <1183291235@qq.com>
2025-06-17 17:49:03 +08:00
final_hidden_states = tensor_model_parallel_all_reduce (
final_hidden_states )
if shared_experts :
return final_hidden_states , shared_hidden_states
else :
return final_hidden_states
2025-06-07 16:46:58 +08:00
# ----------------------------------------- TBO-related --------------------------------------------
def _forward_ms_fused_moe_comp (
self ,
hidden_states : torch . Tensor ,
router_logits : torch . Tensor ,
is_prefill : bool ,
real_top_k ,
enable_force_load_balance : bool = False ,
) :
hidden_states = self . quant_method . apply (
layer = self ,
x = hidden_states ,
router_logits = router_logits ,
top_k = real_top_k ,
renormalize = self . renormalize ,
use_grouped_topk = self . use_grouped_topk ,
global_num_experts = self . global_num_experts ,
expert_map = self . expert_map ,
topk_group = self . topk_group ,
num_expert_group = self . num_expert_group ,
custom_routing_function = self . custom_routing_function ,
scoring_func = self . scoring_func ,
e_score_correction_bias = self . e_score_correction_bias ,
is_prefill = is_prefill ,
2025-08-02 09:49:10 +08:00
enable_force_load_balance = enable_force_load_balance ,
)
return hidden_states