From aff5189c8781f1127d7dbca6c59a9ec44a427100 Mon Sep 17 00:00:00 2001
From: Ruri <33858552+zhoux77899@users.noreply.github.com>
Date: Thu, 4 Sep 2025 11:37:32 +0800
Subject: [PATCH] [main] Fuse GroupedMatmul, Swiglu and DynamicQuant in
`W8A8_DYNAMIC` quantized MoE layers (#2275)
### What this PR does / why we need it?
Fuse `GroupedMatmul`, `Swiglu` and `DynamicQuant` into one fusion
operation `GroupedMatmulSwigluQuant`.
1. extract common functions in `w4a8_dynamic.py` and `w8a8_dynamic.py`
2. in supported scenarios, use the fusion operation
`npu_grouped_matmul_swiglu_quant`
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Tested on W8A8 quantized Qwen3-235B-A22B model with `bs=16`
1. `tp=8`, `dp=1`, `moe_tp=8`, `moe_ep=1`, TPOP increased 21.54%, Output
Token Throughput increased 27.35%
2. `tp=8`, `dp=1`, `moe_tp=1`, `moe_ep=8`, TPOP increased 17.38%, Output
Token Throughput increased 6.86%
- vLLM version: v0.10.1.1
- vLLM main:
https://github.com/vllm-project/vllm/commit/6997a25ac65ed6cc3c2be6d09ca45f633a345f63
---------
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: zhoux77899
---
tests/ut/ops/test_fused_ops.py | 104 +++++++++++-
tests/ut/quantization/test_w8a8_dynamic.py | 49 ++++++
vllm_ascend/ops/fused_moe.py | 6 +-
vllm_ascend/ops/layers/moe_mlp.py | 144 +++++++++++------
vllm_ascend/quantization/w8a8_dynamic.py | 174 +--------------------
5 files changed, 257 insertions(+), 220 deletions(-)
create mode 100644 tests/ut/quantization/test_w8a8_dynamic.py
diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index 6a51d1d..2e1661b 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -29,7 +29,7 @@ from vllm_ascend.ascend_forward_context import (FusedMoEState,
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.layers.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch
adapt_patch(True)
@@ -524,6 +524,43 @@ class TestExpertsSelector:
assert topk_ids.shape == (8, 2)
+class TestCumsumGroupList(TestBase):
+
+ def setUp(self):
+ self.active_num = 8
+ self.expert_num = 128
+ self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64)
+ self.experts[:self.active_num] = 1
+ self.experts = self.experts[torch.randperm(self.expert_num)]
+ self.group_list = self.experts.cumsum(dim=0)
+
+ def test_cumsum_group_list_with_type_0(self):
+ group_list = self.experts.cumsum(dim=0)
+ group_list_type = 0
+ result = cumsum_group_list(group_list, group_list_type)
+ self.assertTrue(torch.equal(result, self.group_list))
+
+ def test_cumsum_group_list_with_type_1(self):
+ group_list = self.experts
+ group_list_type = 1
+ result = cumsum_group_list(group_list, group_list_type)
+ self.assertTrue(torch.equal(result, self.group_list))
+
+ def test_cumsum_group_list_with_type_2(self):
+ tokens = torch.arange(self.expert_num, dtype=torch.int64)
+ group_list = torch.cat([
+ tokens.reshape(self.expert_num, 1),
+ self.experts.reshape(self.expert_num, 1)
+ ],
+ dim=1)
+ group_list_type = 2
+ result = cumsum_group_list(group_list,
+ group_list_type,
+ active_num=self.active_num,
+ expert_num=self.expert_num)
+ self.assertTrue(torch.equal(result, self.group_list))
+
+
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@@ -739,3 +776,68 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
+
+ @patch("vllm_ascend.ops.layers.moe_mlp.get_forward_context")
+ @patch("torch_npu.npu_grouped_matmul")
+ @patch("torch_npu.npu_swiglu")
+ @patch("torch_npu.npu_grouped_matmul_swiglu_quant")
+ @patch("torch_npu.npu_dynamic_quant")
+ def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
+ self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
+ mock_npu_swiglu, mock_npu_grouped_matmul,
+ mock_get_forward_context):
+
+ mock_forward_context = MagicMock()
+ mock_forward_context.with_quant = True
+ mock_forward_context.fused_moe_state = "NOT_MC2"
+ mock_get_forward_context.return_value = mock_forward_context
+
+ mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint(
+ -128, 127, (10, 40),
+ dtype=torch.int8), torch.rand(
+ 10, 1,
+ dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32))
+ mock_npu_grouped_matmul.side_effect = [[
+ torch.randn(10, 20, dtype=torch.bfloat16)
+ ]]
+ mock_npu_swiglu.return_value = torch.randn(10,
+ 40,
+ dtype=torch.bfloat16)
+ mock_npu_dynamic_quant.return_value = (torch.randint(-128,
+ 127, (10, 40),
+ dtype=torch.int8),
+ torch.rand(10,
+ 1,
+ dtype=torch.float32))
+
+ hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
+ w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
+ w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
+ w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
+ w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
+ w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
+ w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
+ group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
+ provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
+
+ result = unified_apply_mlp(hidden_states=hidden_states,
+ w1=w1,
+ w1_scale=w1_scale,
+ w2=w2,
+ w2_scale=w2_scale,
+ group_list=group_list,
+ dynamic_scale=provided_dynamic_scale,
+ group_list_type=1,
+ w1_scale_bias=w1_scale_bias,
+ w2_scale_bias=w2_scale_bias,
+ topk_scales=None,
+ with_quant=True,
+ fusion=True)
+
+ mock_get_forward_context.assert_called()
+ mock_npu_grouped_matmul.assert_called_once()
+ mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
+
+ self.assertTrue(mock_forward_context.with_quant)
+ self.assertEqual(result.shape, hidden_states.shape)
+ self.assertEqual(result.dtype, torch.bfloat16)
diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py
new file mode 100644
index 0000000..690778e
--- /dev/null
+++ b/tests/ut/quantization/test_w8a8_dynamic.py
@@ -0,0 +1,49 @@
+from unittest.mock import Mock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.quantization.w8a8_dynamic import \
+ AscendW8A8DynamicFusedMoEMethod
+
+
+class TestAscendW8A8FusedMoEMethod(TestBase):
+ num_experts = 8
+ hidden_size = 128
+ intermediate_size = 128
+
+ @patch("torch.distributed.get_rank")
+ @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
+ @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
+ @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
+ def setUp(self, mock_get_ep_group, mock_get_ascend_config,
+ mock_get_mc2_group, mock_get_rank):
+ mock_ep_group = Mock()
+ mock_get_ep_group.return_value = mock_ep_group
+ mock_ascend_config = Mock()
+ mock_ascend_config.torchair_graph_config = Mock(enabled=False)
+ mock_get_ascend_config.return_value = mock_ascend_config
+ mock_mc2_group = Mock(device_group=0)
+ mock_get_mc2_group.return_value = mock_mc2_group
+ mock_rank = Mock()
+ mock_get_rank.return_value = mock_rank
+
+ self.quant_method = AscendW8A8DynamicFusedMoEMethod()
+
+ def test_get_weight(self):
+ param_dict = self.quant_method.get_weight(self.num_experts,
+ self.intermediate_size,
+ self.hidden_size,
+ torch.bfloat16)
+ self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
+ self.assertEqual(
+ param_dict["w13_weight"].shape,
+ (self.num_experts, 2 * self.intermediate_size, self.hidden_size))
+
+ def test_get_dynamic_quant_param(self):
+ param_dict = self.quant_method.get_dynamic_quant_param(
+ self.num_experts, self.intermediate_size, self.hidden_size,
+ torch.bfloat16)
+ self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
+ self.assertEqual(param_dict["w13_weight_scale"].shape,
+ (self.num_experts, 2 * self.intermediate_size, 1))
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 14396c1..11c4ec5 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -70,7 +70,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
shared_dequant_scale: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
- with_quant: bool = False):
+ with_quant: bool = False,
+ fusion_mlp: bool = False):
token_dispatcher = get_forward_context().token_dispatcher
results = token_dispatcher.token_dispatch(
@@ -100,7 +101,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=results.get("topk_scales"),
- with_quant=with_quant)
+ with_quant=with_quant,
+ fusion=fusion_mlp)
final_hidden_states = token_dispatcher.token_combine(expert_output)
return final_hidden_states
diff --git a/vllm_ascend/ops/layers/moe_mlp.py b/vllm_ascend/ops/layers/moe_mlp.py
index c73e8ea..d6f67bb 100644
--- a/vllm_ascend/ops/layers/moe_mlp.py
+++ b/vllm_ascend/ops/layers/moe_mlp.py
@@ -18,22 +18,52 @@ from typing import Optional
import torch
import torch_npu
+from torch.nn.functional import pad
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.utils import dispose_tensor, is_310p
+def cumsum_group_list(group_list: torch.Tensor,
+ group_list_type: int,
+ active_num: int = 0,
+ expert_num: int = 0) -> torch.Tensor:
+ if group_list_type not in [0, 1, 2]:
+ raise ValueError(
+ f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
+ )
+
+ if group_list_type == 0:
+ return group_list
+ if group_list_type == 1:
+ return group_list.cumsum(dim=0)
+
+ experts = pad(group_list[:, 0], (1, 0))
+ tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
+ cumsum_group_list = torch.full(size=(expert_num, ),
+ fill_value=active_num,
+ dtype=group_list.dtype,
+ device=group_list.device)
+
+ for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
+ if end > start:
+ cumsum_group_list[start:end] = tokens[i]
+
+ return cumsum_group_list
+
+
def quant_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
- dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
+ dynamic_scale: torch.Tensor = None,
w1_scale_bias: torch.Tensor = None,
- w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
+ w2_scale_bias: torch.Tensor = None,
+ fusion: bool = False) -> torch.Tensor:
if dynamic_scale is None:
unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
@@ -49,31 +79,38 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
if w1_scale_bias is None and is_mc2:
- w1_scale = w1_scale.to(torch.float32)
-
- # gmm1: gate_up_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[w1],
- split_item=3,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=torch.int32)[0]
-
- # act_fn: swiglu
- hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
- x=hidden_states,
- weight_scale=w1_scale,
- activation_scale=pertoken_scale,
- bias=None,
- quant_scale=None,
- quant_offset=None,
- group_index=group_list,
- activate_left=True,
- quant_mode=1,
- )
-
+ if w1_scale.dtype != torch.float32:
+ w1_scale = w1_scale.to(torch.float32)
+ if fusion:
+ # gmm1: gate_up_proj & act_fn: swiglu
+ hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
+ x=hidden_states,
+ weight=w1,
+ group_list=cumsum_group_list(group_list, group_list_type),
+ weight_scale=w1_scale,
+ x_scale=pertoken_scale)
+ else:
+ # gmm1: gate_up_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[w1],
+ split_item=3,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=torch.int32)[0]
+ # act_fn: swiglu
+ hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+ x=hidden_states,
+ weight_scale=w1_scale,
+ activation_scale=pertoken_scale,
+ bias=None,
+ quant_scale=None,
+ quant_offset=None,
+ group_index=group_list,
+ activate_left=True,
+ quant_mode=1,
+ )
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
@@ -92,29 +129,37 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
[group_list[:1],
torch.diff(group_list, dim=0)])
group_list_type = 1
- bias1 = [w1_scale_bias]
+ bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
bias2 = [w2_scale_bias]
# TODO w4a8 scene: dynamic acquisition of dtype in the future
_output_dtype = torch.bfloat16
- # gmm1: gate_up_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[w1],
- scale=[w1_scale],
- bias=bias1,
- per_token_scale=[pertoken_scale],
- split_item=2,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=_output_dtype)[0]
-
- # act_fn: swiglu
- hidden_states = torch_npu.npu_swiglu(hidden_states)
- hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
- hidden_states)
-
+ if fusion:
+ # gmm1: gate_up_proj & act_fn: swiglu
+ hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
+ x=hidden_states,
+ weight=w1,
+ bias=bias1,
+ group_list=cumsum_group_list(group_list, group_list_type),
+ weight_scale=w1_scale,
+ x_scale=pertoken_scale)
+ else:
+ # gmm1: gate_up_proj
+ hidden_states = torch_npu.npu_grouped_matmul(
+ x=[hidden_states],
+ weight=[w1],
+ scale=[w1_scale.to(w2_scale.dtype)],
+ bias=bias1,
+ per_token_scale=[pertoken_scale],
+ split_item=2,
+ group_list_type=group_list_type,
+ group_type=0,
+ group_list=group_list,
+ output_dtype=_output_dtype)[0]
+ # act_fn: swiglu
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
+ hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
+ hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
@@ -127,6 +172,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
+
return hidden_states
@@ -178,7 +224,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
topk_scales: Optional[torch.Tensor] = None,
- with_quant: bool = False) -> torch.Tensor:
+ with_quant: bool = False,
+ fusion: bool = False) -> torch.Tensor:
if with_quant:
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
@@ -189,7 +236,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
- w2_scale_bias=w2_scale_bias)
+ w2_scale_bias=w2_scale_bias,
+ fusion=fusion)
else:
return unquant_apply_mlp(hidden_states=hidden_states,
w1=w1,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 20c68be..f710bd2 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -31,173 +31,7 @@ from vllm_ascend.ops.common_fused_moe import \
fused_experts as unified_fused_experts
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor
-
-
-def apply_mlp_decode(hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w1_scale: torch.Tensor,
- w2: torch.Tensor,
- w2_scale: torch.Tensor,
- group_list: torch.Tensor,
- dynamic_scale: torch.Tensor = None,
- group_list_type: int = 1) -> torch.Tensor:
- """
- apply MLP: gate_up_proj -> swiglu -> down_proj
- Args:
- hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
- w1: expert weights1 with shape
- (num_experts, hidden_size, intermediate_size * 2)
- w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
- w2: expert weights2 with shape
- (num_experts, intermediate_size, hidden_size)
- w2_scale: weights2 scale with shape (num_experts, hidden_size)
- group_list: number of tokens for each expert, follow cumsum mode, and
- with shape (num_experts).
- transpose_weight:
- w1: (num_experts, intermediate_size * 2, hidden_size) ->
- (num_experts, hidden_size, intermediate_size * 2)
- w2: (num_experts, hidden_size, intermediate_size) ->
- (num_experts, intermediate_size, hidden_size)
- Returns:
- hidden_states: output hidden states after MLP.
- """
-
- if dynamic_scale is None:
- unquantized_hidden_states = hidden_states
- hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
- hidden_states)
- # Dispose the original unquantized hidden states
- # to save npu memory because they're no longer used.
- dispose_tensor(unquantized_hidden_states)
- else:
- pertoken_scale = dynamic_scale
-
- # gmm1: gate_up_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[w1],
- split_item=3,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=torch.int32)[0]
-
- # act_fn: swiglu
- hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
- x=hidden_states,
- weight_scale=w1_scale,
- activation_scale=pertoken_scale,
- bias=None,
- quant_scale=None,
- quant_offset=None,
- group_index=group_list,
- activate_left=True,
- quant_mode=1,
- )
-
- # gmm2: down_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[w2],
- scale=[w2_scale],
- per_token_scale=[swiglu_out_scale],
- split_item=2,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=w2_scale.dtype)[0]
- return hidden_states
-
-
-def apply_mlp(hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w1_scale: torch.Tensor,
- w2: torch.Tensor,
- w2_scale: torch.Tensor,
- group_list: torch.Tensor,
- dynamic_scale: torch.Tensor = None,
- group_list_type: int = 1,
- w1_scale_bias: torch.Tensor = None,
- w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
- """
- apply MLP: gate_up_proj -> swiglu -> down_proj
-
- Args:
- hidden_states: input hidden states with shape (num_tokens, hidden_size).
- w1: expert weights1 with shape
- (num_experts, hidden_size, intermediate_size * 2)
- w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
- w2: expert weights2 with shape
- (num_experts, intermediate_size, hidden_size)
- w2_scale: weights2 scale with shape (num_experts, hidden_size)
- group_list: number of tokens for each expert, follow cumsum mode, and
- with shape (num_experts).
- transpose_weight:
- w1: (num_experts, intermediate_size * 2, hidden_size) ->
- (num_experts, hidden_size, intermediate_size * 2)
- w2: (num_experts, hidden_size, intermediate_size) ->
- (num_experts, intermediate_size, hidden_size)
-
- Returns:
- hidden_states: output hidden states after MLP.
- """
-
- if dynamic_scale is None:
- unquantized_hidden_states = hidden_states
- hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
- hidden_states)
- # Dispose the original unquantized hidden states
- # to save npu memory because they're no longer used.
- dispose_tensor(unquantized_hidden_states)
- else:
- pertoken_scale = dynamic_scale
-
- bias1, bias2 = None, None
- _output_dtype = w2_scale.dtype
-
- if w1_scale_bias is not None:
- if group_list_type == 0:
- group_list = torch.cat(
- [group_list[:1], torch.diff(group_list, dim=0)])
- group_list_type = 1
- bias1 = [w1_scale_bias]
- bias2 = [w2_scale_bias]
- # TODO w4a8 scene: dynamic acquisition of dtype in the future
- _output_dtype = torch.bfloat16
-
- # gmm1: gate_up_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[w1],
- scale=[w1_scale],
- bias=bias1,
- per_token_scale=[pertoken_scale],
- split_item=2,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=_output_dtype)[0]
-
- # act_fn: swiglu
- hidden_states = torch_npu.npu_swiglu(hidden_states)
- hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
- hidden_states)
-
- # gmm2: down_proj
- hidden_states = torch_npu.npu_grouped_matmul(
- x=[hidden_states],
- weight=[w2],
- scale=[w2_scale],
- bias=bias2,
- per_token_scale=[swiglu_out_scale],
- split_item=2,
- group_list_type=group_list_type,
- group_type=0,
- group_list=group_list,
- output_dtype=_output_dtype)[0]
-
- return hidden_states
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
class AscendW8A8DynamicLinearMethod:
@@ -418,7 +252,7 @@ class AscendW8A8DynamicFusedMoEMethod:
return unified_fused_experts_eager(
hidden_states=x,
w1=layer.w13_weight,
- w1_scale=layer.w13_weight_scale,
+ w1_scale=layer.w13_weight_scale_fp32,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
@@ -431,7 +265,8 @@ class AscendW8A8DynamicFusedMoEMethod:
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=kwargs.get("mc2_mask", None),
- with_quant=True)
+ with_quant=True,
+ fusion_mlp=True)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
@@ -439,6 +274,7 @@ class AscendW8A8DynamicFusedMoEMethod:
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
+ torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(