#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
|
2025-09-17 10:36:43 +08:00
|
|
|
import os.path
|
2025-09-08 20:09:50 +08:00
|
|
|
from typing import Callable, Optional
|
2025-04-28 21:57:01 +08:00
|
|
|
|
|
|
|
|
import torch
|
2025-08-30 11:00:35 +08:00
|
|
|
import torch_npu
|
2025-07-06 15:29:36 +08:00
|
|
|
from vllm.config import CompilationLevel, get_current_vllm_config
|
2025-09-09 18:19:56 +08:00
|
|
|
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
|
|
|
|
tensor_model_parallel_all_reduce)
|
2025-08-12 21:10:20 +08:00
|
|
|
from vllm.forward_context import get_forward_context
|
2025-08-30 22:28:50 +08:00
|
|
|
from vllm.model_executor.layers.fused_moe.config import \
|
|
|
|
|
FusedMoEParallelConfig # isort: skip
|
2025-08-26 19:05:23 +08:00
|
|
|
from vllm.model_executor.layers.fused_moe.layer import (
|
2025-09-17 10:36:43 +08:00
|
|
|
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
2025-09-19 19:05:01 +08:00
|
|
|
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
2025-04-28 21:57:01 +08:00
|
|
|
|
2025-08-04 15:23:20 +08:00
|
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
2025-08-26 19:05:23 +08:00
|
|
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
2025-09-17 10:36:43 +08:00
|
|
|
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
|
|
|
|
determine_default_log2phy_map)
|
|
|
|
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
2025-09-08 20:09:50 +08:00
|
|
|
from vllm_ascend.ops.moe.experts_selector import select_experts
|
|
|
|
|
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
|
2025-09-16 11:06:00 +08:00
|
|
|
AlltoAllCommImpl, MC2CommImpl,
|
|
|
|
|
NaiveMulticastCommImpl)
|
2025-09-19 11:06:45 +08:00
|
|
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
|
2025-04-28 21:57:01 +08:00
|
|
|
|
# Keep a reference to the upstream constructor so the monkey-patched
# `unquantized_fused_moe_init_func` below can delegate to it.
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
def fused_experts_moge(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    moe_parallel_config: FusedMoEParallelConfig,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """Run MoGE-style MoE experts via NPU grouped matmuls.

    Tokens are sorted by expert id so each expert's tokens are contiguous,
    then two grouped matmuls (gate/up projection + swiglu, then down
    projection) are applied, and finally tokens are restored to their
    original order and reduced over the selected experts.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
        moe_parallel_config: MoE parallel configuration; only ``ep_size``
            is read here.
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts to select.
        global_num_experts: Total number of experts across all EP ranks.
        expert_map: Expert mapping of shape (num_experts,). Currently unused
            in this code path.
        apply_router_weight_on_input: Currently unused in this code path.

    Returns:
        hidden_states: Hidden states after routing, of shape
        (num_tokens, hidden_size).
    """
    ep_size = moe_parallel_config.ep_size
    local_num_experts = global_num_experts // ep_size
    local_num_group = top_k // ep_size

    bsz, _ = hidden_states.shape
    flatten_topk_ids = topk_ids.view(-1)
    # Sort the flattened (token, expert) assignments by expert id so that
    # every expert's tokens form a contiguous group for the grouped matmul.
    sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
    sorted_topk_ids = sorted_topk_ids.to(torch.int32)
    sorted_hidden_states = hidden_states.index_select(
        0, sorted_topk_ids // local_num_group)

    experts_id = torch.arange(0,
                              local_num_experts,
                              dtype=topk_ids.dtype,
                              device=topk_ids.device)
    # Per-expert token counts, turned into a cumulative group boundary list
    # as required by npu_grouped_matmul (group_list_type=0).
    num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
        torch.float32).sum(0)
    topk_scales = topk_weights.view(-1).index_select(
        0, sorted_topk_ids).unsqueeze(-1)
    group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)

    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[sorted_hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    if is_310p():
        # On 310P the swiglu kernel is computed in fp32 and cast back to
        # fp16 afterwards.
        gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
            torch.float16)
    else:
        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
    gate_up_out *= topk_scales

    down_out_list = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    # Invert the earlier sort to restore original token order, then sum the
    # per-expert partial outputs for each token.
    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
    unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
    final_hidden_states = unsorted_hidden_states.reshape(
        bsz, local_num_group, -1).sum(1)

    return final_hidden_states
|
def unquantized_fused_moe_init_func(self, *args, **kwargs):
    """Replacement for ``UnquantizedFusedMoEMethod.__init__``.

    Delegates to the original upstream constructor, then records two
    Ascend-specific flags on the method instance: whether ACL graph mode
    is usable, and whether weights still need a one-time transpose in
    ``process_weights_after_loading``.
    """
    original_unquantized_fused_moe_init_func(self, *args, **kwargs)

    # NOTE: Currently, this self.use_aclgraph is only used in
    # UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
    # ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
    # Once torch.randint_like is supported or removed, this flag can be removed.
    vllm_config = get_current_vllm_config()
    ascend_config = get_ascend_config()

    if ascend_config.torchair_graph_config.enabled:
        # Torchair graph mode and ACL graph mode are mutually exclusive.
        self.use_aclgraph = False
    else:
        self.use_aclgraph = (vllm_config.compilation_config.level
                             == CompilationLevel.PIECEWISE
                             and not vllm_config.model_config.enforce_eager)
    # Weights are transposed once after loading; this flag is cleared in
    # process_weights_after_loading so reloads do not transpose twice.
    self.transpose = True
|
2025-04-28 21:57:01 +08:00
|
|
|
|
|
|
|
|
def forward_oot(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Out-of-tree replacement for ``UnquantizedFusedMoEMethod.forward``.

    Selects the top-k experts for each token and dispatches to either the
    MoGE grouped-matmul path (on 310P, or when the router returned fewer
    than ``top_k`` experts per token) or the MoE communication method's
    fused-experts kernel taken from the current forward context.

    NOTE(review): ``activation``, ``enable_eplb`` and the EPLB-related
    tensors are accepted for interface compatibility but are not forwarded
    on either path here.
    """
    topk_weights, topk_ids, row_idx = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        global_num_experts=global_num_experts)

    # MoGE path: used on 310P hardware, or when the selector returned fewer
    # than top_k expert columns per token.
    if topk_ids.shape[1] < top_k or is_310p():
        assert global_num_experts is not None

        return fused_experts_moge(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            moe_parallel_config=self.moe.moe_parallel_config,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            top_k=top_k,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input)

    # Default path: delegate to the communication method (all-gather,
    # all-to-all, MC2, ...) chosen for this forward pass.
    moe_comm_method = get_forward_context().moe_comm_method
    return moe_comm_method.fused_experts(hidden_states=x,
                                         w1=layer.w13_weight,
                                         w2=layer.w2_weight,
                                         topk_weights=topk_weights,
                                         topk_ids=topk_ids,
                                         row_idx=row_idx,
                                         global_num_experts=global_num_experts,
                                         expert_map=expert_map)
2025-04-28 21:57:01 +08:00
|
|
|
|
|
|
|
|
|
def process_weights_after_loading(self, layer):
    """Replacement for ``UnquantizedFusedMoEMethod.process_weights_after_loading``.

    Pads and (once) transposes the expert weights into the layout expected
    by the NPU grouped-matmul kernels, then casts them to the FRACTAL_NZ
    format on hardware that supports it.
    """
    super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
    if self.transpose:
        # First call after loading: pad, then transpose (E, N, K) -> (E, K, N).
        w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
            1, 2).contiguous()
        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)

        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
            1, 2).contiguous()
        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

        # Clear the flag so a subsequent call (e.g. RL weight reload) does
        # not transpose already-transposed weights again.
        self.transpose = False
    else:
        # Weights were transposed previously; only re-apply padding.
        w13_data = self._maybe_pad_weight(layer.w13_weight.data)
        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)

        w2_data = self._maybe_pad_weight(layer.w2_weight.data)
        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

    # 310P does not support the FRACTAL_NZ cast for these weights.
    if not is_310p():
        layer.w13_weight.data = torch_npu.npu_format_cast(
            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
        layer.w2_weight.data = torch_npu.npu_format_cast(
            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
class AscendFusedMoE(FusedMoE):
    """Ascend-specific ``FusedMoE`` that wires in NPU communication
    implementations, expert-parallel placement maps, and (optionally)
    static or dynamic expert load balancing (EPLB)."""

    # Monotonically increasing id assigned to each instance; used to look up
    # this layer's placement in the static EPLB expert map.
    moe_counter = -1

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        AscendFusedMoE.moe_counter += 1
        self.moe_instance_id = AscendFusedMoE.moe_counter
        self.moe_config.tp_group = get_tp_group()
        self.moe_config.dp_group = get_dp_group()
        self.moe_config.ep_group = get_ep_group()
        self.moe_config.mc2_group = get_mc2_group()
        ascend_config = get_ascend_config()
        self.dynamic_eplb = ascend_config.dynamic_eplb
        self.expert_map_path = ascend_config.expert_map_path
        self.global_redundant_expert_num = ascend_config.init_redundancy_expert
        # static eplb initializing with expert_map_path
        if self.expert_map_path and os.path.exists(
                self.expert_map_path) and os.access(self.expert_map_path,
                                                    os.R_OK):
            self.expert_load_balancer = ExpertLoadBalancer(
                self.expert_map_path, self.global_num_experts)
            self.local_num_experts, self.expert_map = (
                self.expert_load_balancer.get_rank_placement_map(
                    self.moe_instance_id, self.ep_rank))
            self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
                self.moe_instance_id, self.ep_rank).npu()
            self.global_redundant_expert_num = (
                self.expert_load_balancer.get_global_redundant_expert_num())
        else:
            # init moe.
            self.local_num_experts, self.expert_map = determine_expert_map(
                self.ep_size, self.ep_rank, self.global_num_experts)
            # dynamic eplb initializing with not expert_map_path
            if self.dynamic_eplb:
                self.global_redundant_expert_num = ascend_config.init_redundancy_expert
                self.local_num_experts, self.expert_map = determine_default_expert_map(
                    self.global_num_experts, self.ep_size, self.ep_rank,
                    self.global_redundant_expert_num)
                self.log2phy = determine_default_log2phy_map(
                    self.global_num_experts, self.ep_size, self.ep_rank,
                    self.global_redundant_expert_num)
        # Number of experts actually hosted on this rank (entries != -1 in
        # the expert map), falling back to the global count when no map.
        local_num_experts = (torch.sum(
            self.expert_map != -1) if self.expert_map is not None else
                             self.global_num_experts)
        if self.dynamic_eplb:
            # Per-local-expert token counter accumulated in forward_impl.
            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)

        # Instantiate every communication implementation up front; forward
        # passes pick one by name from the forward context (see forward_impl).
        for method in {
                AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
                NaiveMulticastCommImpl
        }:
            setattr(
                self, method.__name__.lower(),
                method(moe_config=self.moe_config))  # type: ignore[abstract]

    def update_expert_map(self, new_expert_map):
        # Used by dynamic EPLB to swap in a rebalanced placement.
        self.expert_map = new_expert_map

    def get_map(self):
        # Current physical placement map for this rank.
        return self.expert_map

    def get_log2phy_map(self):
        # NOTE(review): returns `logical_to_physical_map`, while __init__
        # assigns the map to `self.log2phy` — confirm these refer to the
        # same attribute in the parent class.
        return self.logical_to_physical_map

    def clear_moe_load(self):
        # NOTE(review): `self.moe_load` is only created when dynamic_eplb is
        # enabled; callers are expected to respect that.
        if self.moe_load is not None:
            self.moe_load.zero_()

    def maybe_all_reduce_tensor_model_parallel(
            self, final_hidden_states: torch.Tensor):
        """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
        and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
        the outputs are already aggregated across tensor parallel ranks in the
        `finalize` function. In `allgathercommimpl`, we still need to all-reduce the
        outputs since each rank only has partial outputs.
        """
        forward_context = get_forward_context()
        moe_comm_method_name = forward_context.moe_comm_method_name
        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
            return final_hidden_states
        else:
            return tensor_model_parallel_all_reduce(final_hidden_states)

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        """Prepare inputs via the active comm method, run the quantized or
        unquantized expert kernels, track expert load for dynamic EPLB, and
        finalize (combine/reduce) the outputs."""
        assert self.quant_method is not None

        forward_context = get_forward_context()
        moe_comm_method_name = forward_context.moe_comm_method_name

        # Resolve the comm impl instance created in __init__ by its name.
        forward_context.moe_comm_method = getattr(self, moe_comm_method_name)

        hidden_states, router_logits = forward_context.moe_comm_method.prepare(
            hidden_states=hidden_states, router_logits=router_logits)

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            activation=self.activation,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            enable_eplb=self.enable_eplb,
            expert_load_view=self.expert_load_view,
            logical_to_physical_map=self.logical_to_physical_map,
            logical_replica_count=self.logical_replica_count,
        )
        # Some kernels also return per-expert token counts for load tracking.
        if isinstance(final_hidden_states, tuple):
            final_hidden_states, group_list_type, expert_tokens = final_hidden_states

        if self.dynamic_eplb:
            # group_list_type truthy -> counts are already per-expert;
            # otherwise they are cumulative and must be differenced first.
            self.moe_load += expert_tokens if group_list_type else \
                torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])

        final_hidden_states = forward_context.moe_comm_method.finalize(
            hidden_states=final_hidden_states,
            reduce_results=self.reduce_results)

        return final_hidden_states

    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
        # Ensure training and inference weight shapes match during RL weight updates
        if (
            loaded_weight.shape[1] != expert_data.shape[1] and \
            loaded_weight.shape[0] != expert_data.shape[0]
        ):
            # Both dims mismatch -> the checkpoint stores the weight
            # transposed; swap it (and the shard dimension) to match.
            shard_dim = int(not shard_dim)
            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
        return loaded_weight, shard_dim

    def _load_w13(self,
                  expert_data: torch.Tensor,
                  shard_dim: int,
                  shard_id: str,
                  loaded_weight: torch.Tensor,
                  tp_rank: int,
                  load_full: bool = False):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        loaded_weight, shard_dim = self.transpose_weight(
            loaded_weight, expert_data, shard_dim)
        shard_size = expert_data.shape[shard_dim] // 2
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        loaded_weight, shard_dim = self.transpose_weight(
            loaded_weight, expert_data, shard_dim)
        shard_size = expert_data.shape[shard_dim]
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)
2025-08-26 19:05:23 +08:00
|
|
|
|
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
    """Fused MoE with shared experts, optionally overlapping the shared
    expert computation with the routed experts on a separate NPU stream."""

    def __init__(
        self,
        shared_experts: torch.nn.Module,
        use_overlapped: bool = True,
        **kwargs,
    ):
        # Deliberately skip SharedFusedMoE.__init__ and initialize through
        # the Ascend base class.
        AscendFusedMoE.__init__(self, **kwargs)
        self._shared_experts = shared_experts
        self.use_overlapped = use_overlapped
        self.shared_expert_stream = None
        ascend_config = get_ascend_config()
        self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
        if self.multistream_overlap_shared_expert:
            # Dedicated stream used to run shared experts concurrently with
            # the routed-expert computation.
            self.shared_expert_stream = torch.npu.Stream()

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(shared_out, fused_out)`` for this layer."""
        # Make sure the shared experts stream begins after hidden_states are ready.
        if self.multistream_overlap_shared_expert:
            self.shared_expert_stream.wait_stream(  # type: ignore
                torch.npu.current_stream())
        with npu_stream_switch(self.shared_expert_stream,
                               enabled=self.multistream_overlap_shared_expert):
            # Use a separate stream to run shared experts.
            shared_out = self._shared_experts(hidden_states)

            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
            forward_context = get_forward_context()
            moe_comm_method_name = forward_context.moe_comm_method_name
            if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
                shared_out = tensor_model_parallel_all_reduce(shared_out)

        _, fused_out = AscendFusedMoE.forward(
            self,
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        # Make sure the default stream waits for the shared experts stream to finish.
        if self.multistream_overlap_shared_expert:
            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
        return shared_out, fused_out

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        # NOTE(review): the shared output here is a 1-element placeholder;
        # the real shared-expert output is produced in `forward` above —
        # confirm callers of forward_impl never consume this value.
        shared_output = torch.empty(1)
        fused_output = AscendFusedMoE.forward_impl(
            self,
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        return shared_output, fused_output
2025-09-09 18:19:56 +08:00
|
|
|
|
# Monkey-patch the upstream unquantized MoE method so vLLM picks up the
# Ascend-specific init, weight post-processing, and forward implementations
# defined in this module.
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
UnquantizedFusedMoEMethod.forward_oot = forward_oot