[main][prefill optimization] Optimize parallel strategies to reduce communication overhead (#2198)
### What this PR does / why we need it?
1.Shared Expert Sharding Strategy Update: Switched from TP-aligned to
pure DP for shared experts, enabling more efficient execution.
2.O_Proj AllReduce → ReduceScatter: Reduced communication overhead by
using ReduceScatter, made possible by pure DP sharding.
3.AllGather Postponed: Delayed to after QKV down projection to reduce
synchronization impact during prefill.
### How was this patch tested?
Adding ut case in `tests/ut/attention/test_mla_v1.py`
#### How to run
use parameter `--additional_config='{"enable_shared_expert_dp": true}'`
##### a.How to run eager mode
eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true}'
##### b.How to run graph mode
eg:
python -m vllm.entrypoints.openai.api_server --model=/model_path
--trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002
--max-model-len 5120 --max-num-batched-tokens 16384
--disable-log-requests
--additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":
true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'
- vLLM version: v0.10.0
- vLLM main:
9edd1db02b
---------
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
@@ -691,3 +691,40 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(result.shape[2], self.impl.v_head_dim)
|
||||
mock_up_proj.assert_called_once()
|
||||
mock_page_attention_mla.assert_called_once()
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
|
||||
@patch("torch_npu._npu_reshape_and_cache")
|
||||
def test_forward_without_graph(self, _, mock_forward_prefill):
|
||||
self.impl.running_in_graph = False
|
||||
self.impl.torchair_graph_enabled = False
|
||||
|
||||
num_tokens = 100
|
||||
num_blocks = 256
|
||||
block_size = 4
|
||||
rotary_emb_return_value = (torch.randn(num_tokens, 16,
|
||||
self.impl.kv_lora_rank),
|
||||
torch.randn(0, 1, self.impl.kv_lora_rank))
|
||||
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
|
||||
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
|
||||
1, num_blocks, 128)
|
||||
|
||||
hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
|
||||
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
|
||||
self.impl.kv_lora_rank)
|
||||
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
|
||||
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
|
||||
self.impl.kv_lora_rank),
|
||||
torch.randn(num_blocks, block_size, self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim))
|
||||
output = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.v_head_dim)
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.num_decodes = 0
|
||||
metadata.num_prefills = num_tokens
|
||||
mock_forward_prefill.return_value = torch.randn(
|
||||
0, self.impl.num_heads * self.impl.v_head_dim)
|
||||
result = self.impl.forward(None, hidden_states_or_q_c,
|
||||
hidden_states_or_kv_c_normed, k_pe,
|
||||
kv_cache, metadata, output, False)
|
||||
self.assertEqual(result.shape[0], num_tokens)
|
||||
|
||||
Reference in New Issue
Block a user