[Perf] move quant before allgather in Allgather EP (#3420)
### What this PR does / why we need it?
move quant before allgather in Allgather EP, rely on
https://github.com/vllm-project/vllm-ascend/pull/3334
Deepseek R1 W8A8 performance on A2 with
`HCCL_ALGO="level0:NA;level1:pipeline"`:
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|----------|----------|----------|
| 4k | 375.21 | 364.99 |
| 16k | 1465.23 | 1421.75 |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -458,6 +458,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
dtype=torch.float32))
|
||||
|
||||
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
||||
hidden_states_shape = hidden_states.shape
|
||||
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
|
||||
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
|
||||
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
|
||||
@@ -486,7 +487,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
mock_npu_swiglu.assert_called_once()
|
||||
mock_npu_dynamic_quant.assert_called_once()
|
||||
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.shape, hidden_states_shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
|
||||
@@ -568,6 +569,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
dtype=torch.float32))
|
||||
|
||||
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
||||
hidden_states_shape = hidden_states.shape
|
||||
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
|
||||
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
|
||||
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
|
||||
@@ -596,7 +598,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
|
||||
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.shape, hidden_states_shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user