2025-08-30 22:28:50 +08:00
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
# Copyright 2023 The vLLM team.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# This file is a part of the vllm-ascend project.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch_npu
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
from torch.nn.functional import pad
|
2025-08-30 22:28:50 +08:00
|
|
|
from vllm.forward_context import get_forward_context
|
2025-12-22 16:01:58 +08:00
|
|
|
from vllm.triton_utils import HAS_TRITON
|
2025-08-30 22:28:50 +08:00
|
|
|
|
2025-09-22 19:12:58 +08:00
|
|
|
from vllm_ascend.ascend_forward_context import MoECommType
|
2026-02-06 15:28:49 +08:00
|
|
|
from vllm_ascend.utils import (
|
|
|
|
|
dispose_tensor,
|
|
|
|
|
enable_custom_op,
|
|
|
|
|
get_weight_prefetch_method,
|
|
|
|
|
)
|
2025-11-30 22:52:05 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
|
|
|
|
return fusion and dynamic_eplb and enable_custom_op()
|
2025-08-30 22:28:50 +08:00
|
|
|
|
|
|
|
|
|
2026-02-06 15:28:49 +08:00
|
|
|
def cumsum_group_list(
    group_list: torch.Tensor, src_list_type: int, dst_list_type: int, active_num: int = 0, expert_num: int = 0
) -> torch.Tensor:
    """Convert a MoE ``group_list`` between the layouts selected by ``group_list_type``.

    Layouts, as implied by the conversions implemented below:
      * type 0: 1-D tensor of cumulative token counts (running end offsets) per group.
      * type 1: 1-D tensor of per-group token counts (type 0 == cumsum of type 1).
      * type 2: 2-D tensor whose column 0 is used as expert ids and column 1 as
        token counts (one row per listed expert; presumably sorted by expert id
        and containing only experts that received tokens -- TODO confirm).

    Supported conversions: identity (src == dst), 1 -> 0, 0 -> 1 and 2 -> 0.

    Args:
        group_list: Group description in the ``src_list_type`` layout.
        src_list_type: Layout of ``group_list``; must be 0, 1 or 2.
        dst_list_type: Requested output layout.
        active_num: Only used for 2 -> 0. Fill value for every expert index at or
            after the last expert id listed in ``group_list`` (presumably the total
            number of routed tokens -- confirm against callers).
        expert_num: Only used for 2 -> 0. Length of the returned dense tensor; with
            the default of 0 the result is an empty tensor.

    Returns:
        ``group_list`` itself when ``src_list_type == dst_list_type``, otherwise a
        new tensor in the ``dst_list_type`` layout.

    Raises:
        ValueError: if ``src_list_type`` is not one of 0, 1, 2.
        NotImplementedError: for any (src, dst) pair not listed above.
    """
    if src_list_type not in [0, 1, 2]:
        raise ValueError(f"group_list_type should be in [0, 1, 2], but received {src_list_type}")

    if src_list_type == dst_list_type:
        return group_list

    if src_list_type == 1 and dst_list_type == 0:
        # Note: torch.cumsum promotes integer inputs to int64.
        return group_list.cumsum(dim=0)

    if src_list_type == 0 and dst_list_type == 1:
        # Prepending a zero yields [g0, g1 - g0, ...]. Unlike indexing
        # group_list[0], this also works for an empty group_list.
        return torch.diff(group_list, dim=0, prepend=group_list.new_zeros(1))

    if src_list_type == 2 and dst_list_type == 0:
        # experts = [0, e_0, ..., e_{k-1}]; tokens = [0, c_0, c_0 + c_1, ..., sum(c)].
        experts = pad(group_list[:, 0], (1, 0))
        tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
        dense_cumsum = torch.full(
            size=(expert_num,), fill_value=active_num, dtype=group_list.dtype, device=group_list.device
        )
        # Move the expert ids to the host once. Comparing and slicing with 0-d
        # device tensors inside the loop would force a device->host sync per element.
        expert_ids = experts.tolist()
        for i, (start, end) in enumerate(zip(expert_ids[:-1], expert_ids[1:])):
            if end > start:
                # Expert indices in [start, end) get the running token total of the
                # first i listed rows (tokens[i]).
                dense_cumsum[start:end] = tokens[i]
        return dense_cumsum

    raise NotImplementedError(
        f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
        "This feature is under development."
    )
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
|
|
|
|
|
|
2026-02-06 15:28:49 +08:00
|
|
|
def quant_apply_mlp(
|
|
|
|
|
hidden_states: torch.Tensor,
|
|
|
|
|
w1: list[torch.Tensor],
|
|
|
|
|
w1_scale: list[torch.Tensor],
|
|
|
|
|
w2: list[torch.Tensor],
|
|
|
|
|
w2_scale: list[torch.Tensor],
|
|
|
|
|
group_list: torch.Tensor,
|
|
|
|
|
group_list_type: int = 1,
|
|
|
|
|
dynamic_scale: torch.Tensor = None,
|
|
|
|
|
w1_scale_bias: torch.Tensor = None,
|
|
|
|
|
w2_scale_bias: torch.Tensor = None,
|
|
|
|
|
w1_offset: torch.Tensor | None = None,
|
|
|
|
|
w2_offset: torch.Tensor | None = None,
|
|
|
|
|
fusion: bool = False,
|
|
|
|
|
dynamic_eplb: bool = False,
|
|
|
|
|
) -> torch.Tensor:
|
2025-12-10 15:58:52 +08:00
|
|
|
if w1_offset is not None:
|
|
|
|
|
unquantized_hidden_states = hidden_states
|
|
|
|
|
quantized_hidden_states = None
|
|
|
|
|
elif dynamic_scale is None:
|
2025-08-30 22:28:50 +08:00
|
|
|
unquantized_hidden_states = hidden_states
|
2026-02-06 15:28:49 +08:00
|
|
|
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
2025-08-30 22:28:50 +08:00
|
|
|
# Dispose the original unquantized hidden states
|
|
|
|
|
# to save npu memory because they're no longer used.
|
|
|
|
|
dispose_tensor(unquantized_hidden_states)
|
2025-11-04 16:49:58 +08:00
|
|
|
quantized_hidden_states = None
|
2025-08-30 22:28:50 +08:00
|
|
|
else:
|
2025-12-10 15:58:52 +08:00
|
|
|
unquantized_hidden_states = None
|
2025-08-30 22:28:50 +08:00
|
|
|
pertoken_scale = dynamic_scale
|
2025-11-04 16:49:58 +08:00
|
|
|
quantized_hidden_states = hidden_states
|
2025-08-30 22:28:50 +08:00
|
|
|
|
|
|
|
|
bias1, bias2 = None, None
|
2025-11-30 22:52:05 +08:00
|
|
|
_output_dtype = w2_scale[0].dtype
|
2025-08-30 22:28:50 +08:00
|
|
|
|
2025-12-23 08:49:52 +08:00
|
|
|
weight_prefetch_method = get_weight_prefetch_method()
|
2026-02-10 14:14:37 +08:00
|
|
|
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
|
2025-09-22 19:12:58 +08:00
|
|
|
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
2025-12-10 15:58:52 +08:00
|
|
|
if w1_scale_bias is None and w1_offset is None and is_mc2:
|
2025-11-30 22:52:05 +08:00
|
|
|
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
|
|
|
|
# gmm1: gate_up_proj & act_fn: swiglu
|
2026-02-06 15:28:49 +08:00
|
|
|
hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
|
|
|
|
x=hidden_states,
|
|
|
|
|
weight=w1,
|
|
|
|
|
weight_scale=w1_scale,
|
|
|
|
|
x_scale=pertoken_scale,
|
|
|
|
|
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
|
|
|
|
)
|
2025-11-30 22:52:05 +08:00
|
|
|
elif fusion and not dynamic_eplb:
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
# gmm1: gate_up_proj & act_fn: swiglu
|
|
|
|
|
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
|
|
|
|
x=hidden_states,
|
2025-11-30 22:52:05 +08:00
|
|
|
weight=w1[0],
|
2025-12-12 14:51:20 +08:00
|
|
|
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
2025-11-30 22:52:05 +08:00
|
|
|
weight_scale=w1_scale[0],
|
2026-02-06 15:28:49 +08:00
|
|
|
x_scale=pertoken_scale,
|
|
|
|
|
)
|
2025-11-04 16:49:58 +08:00
|
|
|
if quantized_hidden_states is not None:
|
|
|
|
|
dispose_tensor(quantized_hidden_states)
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
else:
|
2025-11-30 22:52:05 +08:00
|
|
|
if w1_scale[0].dtype != torch.float32:
|
|
|
|
|
w1_scale[0] = w1_scale[0].to(torch.float32)
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
# gmm1: gate_up_proj
|
|
|
|
|
hidden_states = torch_npu.npu_grouped_matmul(
|
|
|
|
|
x=[hidden_states],
|
2025-11-30 22:52:05 +08:00
|
|
|
weight=w1,
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
split_item=3,
|
|
|
|
|
group_list_type=group_list_type,
|
|
|
|
|
group_type=0,
|
|
|
|
|
group_list=group_list,
|
2026-02-06 15:28:49 +08:00
|
|
|
output_dtype=torch.int32,
|
|
|
|
|
)[0]
|
2025-11-04 16:49:58 +08:00
|
|
|
if quantized_hidden_states is not None:
|
|
|
|
|
dispose_tensor(quantized_hidden_states)
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
# act_fn: swiglu
|
|
|
|
|
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
|
|
|
|
x=hidden_states,
|
2025-12-05 16:04:24 +08:00
|
|
|
weight_scale=w1_scale[0],
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
activation_scale=pertoken_scale,
|
|
|
|
|
bias=None,
|
|
|
|
|
quant_scale=None,
|
|
|
|
|
quant_offset=None,
|
2025-12-12 14:51:20 +08:00
|
|
|
group_index=cumsum_group_list(group_list, group_list_type, 1),
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
activate_left=True,
|
|
|
|
|
quant_mode=1,
|
|
|
|
|
)
|
2025-08-30 22:28:50 +08:00
|
|
|
# gmm2: down_proj
|
|
|
|
|
hidden_states = torch_npu.npu_grouped_matmul(
|
|
|
|
|
x=[hidden_states],
|
2025-11-30 22:52:05 +08:00
|
|
|
weight=w2,
|
|
|
|
|
scale=w2_scale,
|
2025-08-30 22:28:50 +08:00
|
|
|
per_token_scale=[swiglu_out_scale],
|
|
|
|
|
split_item=2,
|
|
|
|
|
group_list_type=group_list_type,
|
|
|
|
|
group_type=0,
|
|
|
|
|
group_list=group_list,
|
2026-02-06 15:28:49 +08:00
|
|
|
output_dtype=w2_scale[0].dtype,
|
|
|
|
|
)[0]
|
2025-12-10 15:58:52 +08:00
|
|
|
elif w1_offset is not None:
|
|
|
|
|
# gmm1: gate_up_proj
|
|
|
|
|
hidden_states = torch_npu.npu_grouped_matmul(
|
|
|
|
|
x=[unquantized_hidden_states],
|
|
|
|
|
weight=[w1],
|
|
|
|
|
antiquant_scale=[w1_scale],
|
|
|
|
|
antiquant_offset=[w1_offset],
|
|
|
|
|
split_item=2,
|
|
|
|
|
group_list_type=group_list_type,
|
|
|
|
|
group_type=0,
|
|
|
|
|
group_list=group_list,
|
2026-02-06 15:28:49 +08:00
|
|
|
output_dtype=_output_dtype,
|
|
|
|
|
)[0]
|
2025-12-10 15:58:52 +08:00
|
|
|
dispose_tensor(unquantized_hidden_states)
|
|
|
|
|
# act_fn: swiglu
|
|
|
|
|
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
|
|
|
|
# gmm2: down_proj
|
|
|
|
|
hidden_states = torch_npu.npu_grouped_matmul(
|
|
|
|
|
x=[hidden_states],
|
|
|
|
|
weight=[w2],
|
|
|
|
|
antiquant_scale=[w2_scale],
|
|
|
|
|
antiquant_offset=[w2_offset],
|
|
|
|
|
split_item=2,
|
|
|
|
|
group_list_type=group_list_type,
|
|
|
|
|
group_type=0,
|
|
|
|
|
group_list=group_list,
|
2026-02-06 15:28:49 +08:00
|
|
|
output_dtype=_output_dtype,
|
|
|
|
|
)[0]
|
2025-08-30 22:28:50 +08:00
|
|
|
else:
|
|
|
|
|
if w1_scale_bias is not None:
|
|
|
|
|
if group_list_type == 0:
|
2026-02-06 15:28:49 +08:00
|
|
|
group_list = torch.cat([group_list[:1], torch.diff(group_list, dim=0)])
|
2025-08-30 22:28:50 +08:00
|
|
|
group_list_type = 1
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
|
2025-08-30 22:28:50 +08:00
|
|
|
bias2 = [w2_scale_bias]
|
|
|
|
|
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
|
|
|
|
_output_dtype = torch.bfloat16
|
|
|
|
|
|
2025-11-30 22:52:05 +08:00
|
|
|
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
|
|
|
|
# gmm1: gate_up_proj & act_fn: swiglu
|
2026-02-06 15:28:49 +08:00
|
|
|
hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
|
|
|
|
x=hidden_states,
|
|
|
|
|
weight=w1,
|
|
|
|
|
weight_scale=w1_scale,
|
|
|
|
|
x_scale=pertoken_scale,
|
|
|
|
|
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
|
|
|
|
bias=bias1,
|
|
|
|
|
)
|
2025-11-30 22:52:05 +08:00
|
|
|
elif fusion and not dynamic_eplb:
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
# gmm1: gate_up_proj & act_fn: swiglu
|
|
|
|
|
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
|
|
|
|
x=hidden_states,
|
2025-11-30 22:52:05 +08:00
|
|
|
weight=w1[0],
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
bias=bias1,
|
2025-12-12 14:51:20 +08:00
|
|
|
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
2025-11-30 22:52:05 +08:00
|
|
|
weight_scale=w1_scale[0],
|
2026-02-06 15:28:49 +08:00
|
|
|
x_scale=pertoken_scale,
|
|
|
|
|
)
|
2025-11-04 16:49:58 +08:00
|
|
|
if quantized_hidden_states is not None:
|
|
|
|
|
dispose_tensor(quantized_hidden_states)
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
else:
|
2025-11-30 22:52:05 +08:00
|
|
|
w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype)
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
# gmm1: gate_up_proj
|
|
|
|
|
hidden_states = torch_npu.npu_grouped_matmul(
|
|
|
|
|
x=[hidden_states],
|
2025-11-30 22:52:05 +08:00
|
|
|
weight=w1,
|
|
|
|
|
scale=w1_scale,
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
bias=bias1,
|
|
|
|
|
per_token_scale=[pertoken_scale],
|
|
|
|
|
split_item=2,
|
|
|
|
|
group_list_type=group_list_type,
|
|
|
|
|
group_type=0,
|
|
|
|
|
group_list=group_list,
|
2026-02-06 15:28:49 +08:00
|
|
|
output_dtype=_output_dtype,
|
|
|
|
|
)[0]
|
2025-11-04 16:49:58 +08:00
|
|
|
if quantized_hidden_states is not None:
|
|
|
|
|
dispose_tensor(quantized_hidden_states)
|
[main] Fuse GroupedMatmul, Swiglu and DynamicQuant in `W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. if in supported occasion, use fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/a1a9c14d-2310-41be-9a03-36125dabae6e"
/>
3. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
<img width="3443" height="211" alt="image"
src="https://github.com/user-attachments/assets/1ce92e92-720d-40c0-8b4d-c493e5cb10a6"
/>
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
2025-09-04 11:37:32 +08:00
|
|
|
# act_fn: swiglu
|
2025-12-22 16:01:58 +08:00
|
|
|
if HAS_TRITON:
|
2026-02-06 15:28:49 +08:00
|
|
|
from vllm_ascend.ops.triton.activation.swiglu_quant import swiglu_quant
|
|
|
|
|
|
2025-12-22 16:01:58 +08:00
|
|
|
hidden_states, swiglu_out_scale = swiglu_quant(
|
2026-02-06 15:28:49 +08:00
|
|
|
hidden_states, group_list=group_list, group_list_type=group_list_type
|
|
|
|
|
)
|
2025-12-22 16:01:58 +08:00
|
|
|
else:
|
|
|
|
|
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
2026-02-06 15:28:49 +08:00
|
|
|
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
2025-08-30 22:28:50 +08:00
|
|
|
# gmm2: down_proj
|
|
|
|
|
hidden_states = torch_npu.npu_grouped_matmul(
|
|
|
|
|
x=[hidden_states],
|
2025-11-30 22:52:05 +08:00
|
|
|
weight=w2,
|
|
|
|
|
scale=w2_scale,
|
2025-08-30 22:28:50 +08:00
|
|
|
bias=bias2,
|
|
|
|
|
per_token_scale=[swiglu_out_scale],
|
|
|
|
|
split_item=2,
|
|
|
|
|
group_list_type=group_list_type,
|
|
|
|
|
group_type=0,
|
|
|
|
|
group_list=group_list,
|
2026-02-06 15:28:49 +08:00
|
|
|
output_dtype=_output_dtype,
|
|
|
|
|
)[0]
|
2025-08-30 22:28:50 +08:00
|
|
|
return hidden_states
|
|
|
|
|
|
|
|
|
|
|
2026-02-06 15:28:49 +08:00
|
|
|
def unquant_apply_mlp(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    group_list: torch.Tensor,
    group_list_type: int = 1,
    topk_scales: torch.Tensor | None = None,
    need_trans: bool = True,
) -> torch.Tensor:
    """Apply a non-quantized grouped (per-expert) MLP: gate_up -> SwiGLU -> down.

    Args:
        hidden_states: Input activations fed to the first grouped matmul.
        w1: gate_up_proj weights for all groups.
        w2: down_proj weights for all groups.
        group_list: Group description passed through to
            ``torch_npu.npu_grouped_matmul``.
        group_list_type: How ``group_list`` is interpreted by
            ``npu_grouped_matmul``. This function only forwards it.
        topk_scales: Optional tensor multiplied (in place) into the SwiGLU
            output before the down projection.
        need_trans: When True, ``w1`` and ``w2`` are transposed along dims 1
            and 2 before use. Presumably this converts a stored
            (groups, out, in) layout into the layout the kernel expects
            (TODO confirm).

    Returns:
        The output of the down-projection grouped matmul.
    """
    # Both grouped matmuls share the same grouping configuration.
    grouping = {
        "split_item": 2,
        "group_list_type": group_list_type,
        "group_type": 0,
        "group_list": group_list,
    }

    gate_up_weight, down_weight = (
        (w1.transpose(1, 2), w2.transpose(1, 2)) if need_trans else (w1, w2)
    )

    # gmm1: gate_up_proj, followed by the SwiGLU activation.
    activated = torch_npu.npu_grouped_matmul(
        x=[hidden_states], weight=[gate_up_weight], **grouping
    )[0]
    activated = torch_npu.npu_swiglu(activated)

    if topk_scales is not None:
        # In-place is safe here: `activated` is a fresh tensor produced above.
        activated *= topk_scales

    # gmm2: down_proj
    return torch_npu.npu_grouped_matmul(
        x=[activated], weight=[down_weight], **grouping
    )[0]
|
|
|
|
|
|
|
|
|
|
|
2026-02-06 15:28:49 +08:00
|
|
|
def unified_apply_mlp(
|
|
|
|
|
hidden_states: torch.Tensor,
|
|
|
|
|
w1: torch.Tensor | list[torch.Tensor],
|
|
|
|
|
w2: torch.Tensor | list[torch.Tensor],
|
|
|
|
|
group_list: torch.Tensor,
|
|
|
|
|
w1_scale: list[torch.Tensor] | None = None,
|
|
|
|
|
w2_scale: list[torch.Tensor] | None = None,
|
|
|
|
|
dynamic_scale: torch.Tensor = None,
|
|
|
|
|
group_list_type: int = 1,
|
|
|
|
|
w1_scale_bias: torch.Tensor = None,
|
|
|
|
|
w2_scale_bias: torch.Tensor = None,
|
|
|
|
|
w1_offset: torch.Tensor | None = None,
|
|
|
|
|
w2_offset: torch.Tensor | None = None,
|
|
|
|
|
topk_scales: torch.Tensor | None = None,
|
|
|
|
|
with_quant: bool = False,
|
|
|
|
|
fusion: bool = False,
|
|
|
|
|
need_trans: bool = True,
|
|
|
|
|
dynamic_eplb: bool = False,
|
|
|
|
|
) -> torch.Tensor:
|
2025-08-30 22:28:50 +08:00
|
|
|
if with_quant:
|
2025-11-30 22:52:05 +08:00
|
|
|
assert w1_scale is not None and w2_scale is not None
|
2026-02-06 15:28:49 +08:00
|
|
|
return quant_apply_mlp(
|
|
|
|
|
hidden_states=hidden_states,
|
|
|
|
|
w1=w1,
|
|
|
|
|
w1_scale=w1_scale,
|
|
|
|
|
w2=w2,
|
|
|
|
|
w2_scale=w2_scale,
|
|
|
|
|
group_list=group_list,
|
|
|
|
|
dynamic_scale=dynamic_scale,
|
|
|
|
|
group_list_type=group_list_type,
|
|
|
|
|
w1_scale_bias=w1_scale_bias,
|
|
|
|
|
w2_scale_bias=w2_scale_bias,
|
|
|
|
|
w1_offset=w1_offset,
|
|
|
|
|
w2_offset=w2_offset,
|
|
|
|
|
fusion=fusion,
|
|
|
|
|
dynamic_eplb=dynamic_eplb,
|
|
|
|
|
)
|
2025-08-30 22:28:50 +08:00
|
|
|
else:
|
2026-02-06 15:28:49 +08:00
|
|
|
return unquant_apply_mlp(
|
|
|
|
|
hidden_states=hidden_states,
|
|
|
|
|
w1=w1,
|
|
|
|
|
w2=w2,
|
|
|
|
|
group_list=group_list,
|
|
|
|
|
group_list_type=group_list_type,
|
|
|
|
|
topk_scales=topk_scales,
|
|
|
|
|
need_trans=need_trans,
|
|
|
|
|
)
|