2025-04-28 21:57:01 +08:00
|
|
|
#
|
|
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
# This file is a part of the vllm-ascend project.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
#
|
|
|
|
|
|
2025-08-26 19:05:23 +08:00
|
|
|
from typing import Any, Callable, Optional
|
2025-04-28 21:57:01 +08:00
|
|
|
|
|
|
|
|
import torch
|
2025-07-06 15:29:36 +08:00
|
|
|
from vllm.config import CompilationLevel, get_current_vllm_config
|
2025-08-26 19:05:23 +08:00
|
|
|
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
2025-08-12 21:10:20 +08:00
|
|
|
from vllm.forward_context import get_forward_context
|
2025-08-26 19:05:23 +08:00
|
|
|
from vllm.model_executor.layers.fused_moe.layer import (
|
|
|
|
|
FusedMoE, UnquantizedFusedMoEMethod)
|
2025-04-28 21:57:01 +08:00
|
|
|
|
2025-08-04 15:23:20 +08:00
|
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
2025-08-26 19:05:23 +08:00
|
|
|
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
|
|
|
|
DummyCommImpl,
|
|
|
|
|
MC2CommImpl,
|
|
|
|
|
MoECommMethod)
|
|
|
|
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|
|
|
|
from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
|
2025-08-14 11:50:53 +08:00
|
|
|
from vllm_ascend.ops.layers.experts_selector import select_experts
|
[Platform] Add initial experimental support for Altlas 300I series (#1333)
### What this PR does / why we need it?
Add initial experimental support for Ascend 310P, this patch squash
below PR into one to help validation:
- https://github.com/vllm-project/vllm-ascend/pull/914
- https://github.com/vllm-project/vllm-ascend/pull/1318
- https://github.com/vllm-project/vllm-ascend/pull/1327
### Does this PR introduce _any_ user-facing change?
User can run vLLM on Altlas 300I DUO series
### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit test missing because need a real 310P image to have the test,
will add in a separate PR later.
- Manually e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322
- Pangu MGoE 72B
The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.
#### ENV information
CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
format and calling NNAL operators on 310P
#### Code example
##### Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
##### Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
"水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct",
max_model_len=4096,
max_num_seqs=4,
dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
disable_custom_all_reduce=True,
trust_remote_code=True,
tensor_parallel_size=2,
compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
---------
Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
2025-06-21 09:00:16 +08:00
|
|
|
from vllm_ascend.utils import is_310p
|
2025-04-28 21:57:01 +08:00
|
|
|
|
2025-07-06 15:29:36 +08:00
|
|
|
# Keep a reference to the stock vLLM initializer so the patched
# unquantized_fused_moe_init_func defined below can delegate to it.
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
|
|
|
|
|
|
|
|
|
|
2025-08-26 19:05:23 +08:00
|
|
|
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_int8_w8a8: bool = False,
    use_int4_w4a8: bool = False,
    global_num_experts: Optional[int] = None,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_scale_bias: Optional[torch.Tensor] = None,
    w2_scale_bias: Optional[torch.Tensor] = None,
    moe_comm_method: Optional[MoECommMethod] = None,
    # For TorchAir graph
    is_torchair: bool = False,
    # For Cube/Vector parallel
    shared_experts: Optional[Any] = None,
    quantized_x_for_share: Optional[Any] = None,
    dynamic_scale_for_share: Optional[Any] = None,
    # For load balance
    log2phy: Optional[torch.Tensor] = None,
    global_redundant_expert_num: int = 0,
) -> torch.Tensor:
    """Run the fused MoE expert computation on Ascend.

    The flow is: permute tokens per expert via ``moe_comm_method``, apply the
    grouped expert MLP (``apply_mlp``), then unpermute the results back into
    ``hidden_states``. The result is written into ``hidden_states`` in place
    and the same tensor is returned.

    Note: many parameters (``activation``, the quantization flags/scales,
    ``is_torchair``, the shared-expert and load-balance arguments) are not
    used by this implementation; they are accepted only to keep the call
    signature compatible with other ``fused_experts`` variants.

    Args:
        hidden_states: Input activations of shape ``(num_tokens, hidden_size)``.
        w1: Gate/up projection weights, one slice per expert
            (``w1.shape[0]`` is the local expert count).
        w2: Down projection weights, one slice per expert.
        topk_weights: Router weights, same shape as ``topk_ids``.
        topk_ids: Selected expert ids per token.
        expert_map: Optional mapping used by the comm method's permute step.
        moe_comm_method: Communication strategy supplying ``permute`` /
            ``unpermute``; required (asserted below).

    Returns:
        ``hidden_states``, updated in place with the MoE output.
    """
    # Check constraints
    assert hidden_states.shape[1] == w1.shape[2], (
        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    assert moe_comm_method is not None, "Missing communication context"

    num_experts = w1.shape[0]

    # Group tokens by expert so the expert MLPs can run as grouped matmuls.
    permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts)

    mlp_output = apply_mlp(
        permuted_hidden_states,
        w1,
        w2,
        expert_tokens,
        group_list_type=group_list_type,
    )

    # Scatter expert outputs back to token order, writing into hidden_states.
    moe_comm_method.unpermute(mlp_output, hidden_states)

    return hidden_states
|
|
|
|
|
|
|
|
|
|
|
2025-07-06 15:29:36 +08:00
|
|
|
def unquantized_fused_moe_init_func(self, *args, **kwargs):
    """Replacement for ``UnquantizedFusedMoEMethod.__init__``.

    Delegates to the stock vLLM initializer, then records the scheduler's
    token budget and whether ACL graph capture may be used for this method.
    """
    # Run the original initializer first so every base attribute exists.
    original_unquantized_fused_moe_init_func(self, *args, **kwargs)

    config = get_current_vllm_config()
    self.max_num_batched_tokens = config.scheduler_config.max_num_batched_tokens

    # ACL graph is mutually exclusive with the TorchAir graph mode; otherwise
    # it is only usable under piecewise compilation without enforced eager.
    if get_ascend_config().torchair_graph_config.enabled:
        self.use_aclgraph = False
    else:
        piecewise = (config.compilation_config.level ==
                     CompilationLevel.PIECEWISE)
        self.use_aclgraph = (piecewise
                             and not config.model_config.enforce_eager)
|
2025-07-06 15:29:36 +08:00
|
|
|
|
2025-04-28 21:57:01 +08:00
|
|
|
|
|
|
|
|
def forward_oot(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Out-of-tree MoE forward installed on ``UnquantizedFusedMoEMethod``.

    Routes tokens via ``select_experts`` and then dispatches to one of two
    expert kernels: the MoGE-style path (``fused_experts_moge``) on 310P
    hardware or when the selector produced fewer than ``top_k`` experts per
    token, otherwise the comm-method-driven ``fused_experts`` path.
    """
    routing_weights, routing_ids, _ = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        global_num_experts=global_num_experts)

    # MoGE-style kernel path: taken on Altlas 300I (310P) hardware, or when
    # the selector returned fewer expert slots than requested.
    if routing_ids.shape[1] < top_k or is_310p():
        assert global_num_experts is not None
        return fused_experts_moge(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            moe_parallel_config=self.moe.moe_parallel_config,
            topk_weights=routing_weights,
            topk_ids=routing_ids,
            top_k=top_k,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input)

    # Default path: the communication strategy was chosen when the forward
    # context was set up (see AscendFusedMoE.forward_impl).
    return fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=routing_weights,
        topk_ids=routing_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        moe_comm_method=get_forward_context().moe_comm_method,
    )
|
2025-04-28 21:57:01 +08:00
|
|
|
|
|
|
|
|
|
2025-08-26 19:05:23 +08:00
|
|
|
class AscendFusedMoE(FusedMoE):
    """``FusedMoE`` specialization for Ascend NPUs.

    On top of the upstream layer this attaches the Ascend communication
    groups (TP/DP/EP/MC2) to ``moe_config`` and pre-instantiates every MoE
    communication strategy so ``forward_impl`` can select one per step from
    the forward context.
    """

    def __init__(
        self,
        num_experts,
        top_k,
        hidden_size,
        intermediate_size,
        params_dtype=None,
        reduce_results=False,
        renormalize=True,
        use_grouped_topk=False,
        num_expert_group=None,
        topk_group=None,
        quant_config=None,
        tp_size=None,
        ep_size=None,
        dp_size=None,
        prefix="",
        custom_routing_function=None,
        scoring_func="softmax",
        e_score_correction_bias=None,
        apply_router_weight_on_input=False,
        activation="silu",
        enable_eplb=False,
        num_redundant_experts=0,
        has_bias=False,
    ):
        # Positional pass-through to vLLM's FusedMoE; the argument order must
        # match the upstream signature exactly.
        super().__init__(
            num_experts,
            top_k,
            hidden_size,
            intermediate_size,
            params_dtype,
            reduce_results,
            renormalize,
            use_grouped_topk,
            num_expert_group,
            topk_group,
            quant_config,
            tp_size,
            ep_size,
            dp_size,
            prefix,
            custom_routing_function,
            scoring_func,
            e_score_correction_bias,
            apply_router_weight_on_input,
            activation,
            enable_eplb,
            num_redundant_experts,
            has_bias,
        )

        # Resolve the Ascend process groups once and expose them through
        # moe_config for the communication strategies below.
        self.moe_config.tp_group = get_tp_group()
        self.moe_config.dp_group = get_dp_group()
        self.moe_config.ep_group = get_ep_group()
        self.moe_config.mc2_group = get_mc2_group()

        # Instantiate each communication strategy and store it under the
        # lowercased class name (e.g. "mc2commimpl") so forward_impl can look
        # it up with getattr using the name carried in the forward context.
        for method in {AllGatherCommImpl, DummyCommImpl, MC2CommImpl}:
            setattr(
                self, method.__name__.lower(),
                method(moe_config=self.moe_config))  # type: ignore[abstract]

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        """Run the MoE layer: prepare -> quant_method.apply -> finalize.

        The communication strategy named in the forward context brackets the
        expert computation with its prepare/finalize collectives.
        """
        assert self.quant_method is not None

        forward_context = get_forward_context()
        moe_comm_method_name = forward_context.moe_comm_method_name
        # Without expert parallelism, fall back to all-gather — unless the
        # no-op dummy strategy was explicitly requested.
        if not self.moe_config.use_ep and moe_comm_method_name != "dummycommimpl":
            moe_comm_method_name = "allgathercommimpl"

        forward_context.moe_comm_method = getattr(self, moe_comm_method_name)

        hidden_states, router_logits = forward_context.moe_comm_method.prepare(
            hidden_states=hidden_states, router_logits=router_logits)

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            activation=self.activation,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            enable_eplb=self.enable_eplb,
            expert_load_view=self.expert_load_view,
            logical_to_physical_map=self.logical_to_physical_map,
            logical_replica_count=self.logical_replica_count,
        )

        final_hidden_states = forward_context.moe_comm_method.finalize(
            hidden_states=final_hidden_states,
            reduce_results=self.reduce_results)

        return final_hidden_states
|
|
|
|
|
|
|
|
|
|
|
2025-07-06 15:29:36 +08:00
|
|
|
# Monkey-patch vLLM's UnquantizedFusedMoEMethod so the Ascend-specific
# initializer and out-of-tree forward path replace the upstream ones.
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|