adjusting the communication method in graph mode (#1194)

### What this PR does / why we need it?
Communication performance optimization: replace allreduce with
reduce_scatter+all_gather in MLA layer's TP group,to remove
stridedsliced and all_gather in MOE layer.
when tp > 1, It is enabled during the decode phase of the graph mode
when enable_multistream_moe、MLA, use_v1, and MC2 are used.
According to the end-to-end RL inference test results, this PR can bring
3% gain in the decode stage.

**Before Improvement**
Profiling kernel_details

![image](https://github.com/user-attachments/assets/1bb5dfa1-809b-410a-90c9-c5fd23cff003)
Evaluation

![image](https://github.com/user-attachments/assets/0b8ea0c7-88e7-410f-9ef4-f0cfe910cdc7)

![image](https://github.com/user-attachments/assets/94fde910-c125-4c2e-8de4-88fc3fafc057)

**After Improvement**
Profiling kernel_details

![image](https://github.com/user-attachments/assets/55fac0e0-11f2-4654-8fd4-287949e0b29e)
Evaluation

![image](https://github.com/user-attachments/assets/e923f74b-29c4-4171-9382-40a00cf05df0)

![image](https://github.com/user-attachments/assets/5dba7967-07ea-4926-a8be-804bfd34e3e4)

### Does this PR introduce _any_ user-facing change?
Users need to configure enable_multistream_moe=True

### How was this patch tested?
Add e2e test cases to cover code logic

Signed-off-by: sharonyunyun <zhangying134@huawei.com>
This commit is contained in:
sharonyunyun
2025-06-25 19:56:49 +08:00
committed by GitHub
parent 205cb85a1e
commit 941269a6c5
6 changed files with 195 additions and 37 deletions

View File

@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
MLAAttentionImpl)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
@@ -557,6 +558,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -586,7 +588,7 @@ class AscendMLAImpl(MLAAttentionImpl):
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)[0]
return self.o_proj(x, is_prefill=False)[0]
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
@@ -847,12 +849,12 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self.o_proj(attn_output)[0]
return self.o_proj(attn_output, is_prefill=True)[0]
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
return self.o_proj(attn_output)[0]
return self.o_proj(attn_output, is_prefill=True)[0]
def exec_kv(
self,