[Perf] move quant before allgather in Allgather EP (#3420)

### What this PR does / why we need it?
Move quantization before the allgather in Allgather EP, so that quantized activations are gathered across ranks instead of bf16 ones. Relies on
https://github.com/vllm-project/vllm-ascend/pull/3334.

DeepSeek R1 W8A8 performance on A2 with
`HCCL_ALGO="level0:NA;level1:pipeline"`:

| Seq length | Mean TTFT (ms), main | Mean TTFT (ms), this PR |
|------------|----------------------|-------------------------|
| 4k         | 375.21               | 364.99                  |
| 16k        | 1465.23              | 1421.75                 |

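A rough back-of-the-envelope for where the TTFT gain comes from: with W8A8, quantizing before the all-gather lets each rank exchange int8 activations (plus a per-token fp32 scale) instead of bf16 values, roughly halving the gathered bytes. A minimal sketch under stated assumptions: `allgather_bytes` is a hypothetical helper, not vLLM code, and hidden size 7168 is assumed for DeepSeek R1.

```python
# Hypothetical sketch (not the PR's actual code): communication volume of
# the all-gather when quantization happens after vs. before the collective.

def allgather_bytes(num_tokens: int, hidden_size: int,
                    bytes_per_elem: int, scale_bytes_per_token: int = 0) -> int:
    """Bytes each rank contributes to the all-gather."""
    return num_tokens * (hidden_size * bytes_per_elem + scale_bytes_per_token)

tokens, hidden = 4096, 7168  # e.g. a 4k prefill; 7168 assumed hidden size

# Quant after all-gather: bf16 activations (2 bytes/elem) are gathered.
bf16_bytes = allgather_bytes(tokens, hidden, 2)

# Quant before all-gather: int8 activations (1 byte/elem) plus one fp32
# dynamic-quant scale (4 bytes) per token are gathered.
int8_bytes = allgather_bytes(tokens, hidden, 1, scale_bytes_per_token=4)

ratio = bf16_bytes / int8_bytes  # just under 2x reduction
```

The saving scales with sequence length, which is consistent with the larger absolute TTFT improvement at 16k than at 4k in the table above.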
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main: 83f478bb19

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Commit bedf223771 (parent 44b58b8665) by realliujiaxu, committed via GitHub on 2025-11-04 16:49:58 +08:00.
10 changed files with 160 additions and 66 deletions.


```diff
@@ -458,6 +458,7 @@ class TestUnifiedApplyMLP(TestBase):
                 dtype=torch.float32))
         hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
+        hidden_states_shape = hidden_states.shape
         w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
         w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
         w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
@@ -486,7 +487,7 @@ class TestUnifiedApplyMLP(TestBase):
         mock_npu_swiglu.assert_called_once()
         mock_npu_dynamic_quant.assert_called_once()
-        self.assertEqual(result.shape, hidden_states.shape)
+        self.assertEqual(result.shape, hidden_states_shape)
         self.assertEqual(result.dtype, torch.bfloat16)

     @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
@@ -568,6 +569,7 @@ class TestUnifiedApplyMLP(TestBase):
                 dtype=torch.float32))
         hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
+        hidden_states_shape = hidden_states.shape
         w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
         w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
         w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
@@ -596,7 +598,7 @@ class TestUnifiedApplyMLP(TestBase):
         mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
         self.assertTrue(mock_forward_context.with_quant)
-        self.assertEqual(result.shape, hidden_states.shape)
+        self.assertEqual(result.shape, hidden_states_shape)
         self.assertEqual(result.dtype, torch.bfloat16)
```
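The test changes above follow a common pattern: when the function under test may rebind or resize its input (here, the quantized MLP path may replace `hidden_states` earlier than before), assertions should compare against a shape captured *before* the call rather than the possibly stale input. A minimal standalone illustration, with a hypothetical `apply_mlp` stand-in (plain Python lists in place of tensors):

```python
# Hypothetical illustration of the test pattern: capture the expected
# shape before calling an op that may mutate or extend its input.

def apply_mlp(hidden_states):
    # Stand-in for an op that reuses/extends its input storage while
    # returning an output of the original logical shape.
    hidden_states.append(0.0)                # input no longer matches output
    return [0.0] * (len(hidden_states) - 1)

x = [1.0, 2.0, 3.0]
shape_before = len(x)        # analogous to hidden_states_shape in the diff
out = apply_mlp(x)

assert len(out) == shape_before   # robust: compares against the saved shape
assert len(out) != len(x)         # comparing against the mutated input fails
```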