adjusting the communication method in graph mode (#1194)
### What this PR does / why we need it?
Communication performance optimization: replace `allreduce` with `reduce_scatter` + `all_gather` in the MLA layer's TP group, which removes the `StridedSlice` and `all_gather` operations in the MOE layer. When `tp > 1`, the optimization is enabled during the decode phase of graph mode, provided that `enable_multistream_moe`, MLA, `use_v1`, and MC2 are all in use. According to end-to-end RL inference test results, this PR brings a 3% gain in the decode stage.

**Before Improvement**

*(Profiling `kernel_details` and evaluation screenshots omitted.)*

**After Improvement**

*(Profiling `kernel_details` and evaluation screenshots omitted.)*

### Does this PR introduce _any_ user-facing change?
Users need to configure `enable_multistream_moe=True`.

### How was this patch tested?
Added e2e test cases to cover the code logic.

Signed-off-by: sharonyunyun <zhangying134@huawei.com>
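The core idea is that a tensor-parallel `all_reduce` is mathematically equivalent to a `reduce_scatter` followed by an `all_gather` over the same group; splitting it exposes an intermediate per-rank shard that a downstream TP-sharded op can consume directly. A minimal PyTorch sketch of the two paths, assuming a `tp_group` process group and a leading dimension divisible by the TP size (illustrative only, not the vllm-ascend implementation):

```python
import torch
import torch.distributed as dist

def output_via_allreduce(partial: torch.Tensor, tp_group) -> torch.Tensor:
    """Baseline: one all_reduce sums the per-rank partial results."""
    dist.all_reduce(partial, group=tp_group)
    return partial

def output_via_rs_ag(partial: torch.Tensor, tp_group) -> torch.Tensor:
    """Equivalent decomposition: reduce_scatter + all_gather.

    The intermediate `shard` already holds this rank's reduced slice,
    so a following TP-sharded op (e.g. the MOE layer) can read it
    directly instead of doing its own StridedSlice + all_gather.
    """
    tp_size = dist.get_world_size(group=tp_group)
    shard = partial.new_empty((partial.shape[0] // tp_size,) + partial.shape[1:])
    dist.reduce_scatter_tensor(shard, partial, group=tp_group)
    out = torch.empty_like(partial)
    dist.all_gather_into_tensor(out, shard, group=tp_group)
    return out
```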
```diff
@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -557,6 +558,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -586,7 +588,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return self.o_proj(x)[0]
+        return self.o_proj(x, is_prefill=False)[0]
 
     # Return `ql_nope`, `q_pe`
     def _q_proj_and_k_up_proj(self, x):
@@ -847,12 +849,12 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self.o_proj(attn_output)[0]
+            return self.o_proj(attn_output, is_prefill=True)[0]
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 current_ms_metadata.before_comm_event.wait()
-                return self.o_proj(attn_output)[0]
+                return self.o_proj(attn_output, is_prefill=True)[0]
 
     def exec_kv(
         self,
```
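The `is_prefill` flag threaded through the `o_proj` calls above is what lets the layer pick the communication path per phase. A hypothetical sketch of how such a forward might branch, reusing `output_via_rs_ag` (and the imports) from the earlier sketch; the name, signature, and structure are assumptions, not the actual `o_proj` code:

```python
def o_proj_forward(weight: torch.Tensor, x: torch.Tensor, tp_group,
                   is_prefill: bool = True) -> torch.Tensor:
    """Hypothetical row-parallel o_proj forward with a phase switch."""
    partial = torch.matmul(x, weight)  # per-rank partial sum
    if is_prefill or dist.get_world_size(group=tp_group) == 1:
        # Prefill keeps the original single all_reduce.
        dist.all_reduce(partial, group=tp_group)
        return partial
    # Decode: reduce_scatter + all_gather, letting the MOE layer consume
    # the scattered shard without an extra StridedSlice.
    return output_via_rs_ag(partial, tp_group)
```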