Support multistream of MLA vector operations (#1135)
### What this PR does / why we need it?
Move all vector operations to a secondary stream, with the expected
overlaping being:
```
| q_rmsnorm | | kv_norm_rope_cache | | q_rope |
| matmul W_DQ | matmul W_DKV | index | index | matmul W_UQ | split | matmul W_KV_T |
```
Currently, the `IndexByTensor` operators introduced by computation of
`cos` and `sin` can't be offloaded to the secondary stream due to a
known bug of graph fusion optimization pass. So we instead keep it in
the main stream, only requires it be computed before `matmul W_UQ` to
avoid hindering later overlapping. The problem may be solved by later
optimization (#993), which hoists the computation of `cos` and `sin` up
to the first layer.
### Does this PR introduce _any_ user-facing change?
Controlled by `torchair_graph_config.enable_multistream_mla`, defaulted
to False.
### How was this patch tested?
Tested on 1x16 910 node, with tailored 2 layer DSKv2.
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
This commit is contained in:
@@ -71,7 +71,8 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.utils import dispose_tensor
|
||||
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
|
||||
npu_wait_tensor)
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
@@ -496,6 +497,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_multistream_mla = \
|
||||
ascend_config.torchair_graph_config.enable_multistream_mla
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -505,7 +508,14 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
use_multistream_mla = (self.enable_multistream_mla
|
||||
and attn_metadata is not None
|
||||
and attn_metadata.num_decodes > 0)
|
||||
npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
|
||||
with npu_stream_switch("mla_secondary",
|
||||
0,
|
||||
enabled=use_multistream_mla):
|
||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
||||
else:
|
||||
hidden_states_or_q_c = hidden_states
|
||||
if self.torchair_graph_enabled:
|
||||
|
||||
Reference in New Issue
Block a user