Support multistream of MLA vector operations (#1135)
### What this PR does / why we need it?
Move all vector operations to a secondary stream, with the expected
overlapping being:
```
| q_rmsnorm | | kv_norm_rope_cache | | q_rope |
| matmul W_DQ | matmul W_DKV | index | index | matmul W_UQ | split | matmul W_KV_T |
```
Currently, the `IndexByTensor` operators introduced by the computation of
`cos` and `sin` can't be offloaded to the secondary stream due to a
known bug in the graph fusion optimization pass. We therefore keep them
on the main stream and only require that they be computed before
`matmul W_UQ`, so that they do not hinder the later overlapping. This may
be resolved by a follow-up optimization (#993), which hoists the
computation of `cos` and `sin` up to the first layer.
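The workaround boils down to expressing "`cos`/`sin` must be computed before `matmul W_UQ`" as an artificial data dependency. Below is a minimal sketch of the idea using the `npu_wait_tensor` helper that the diff imports from `vllm_ascend.utils`; the function and tensor names are illustrative, not the actual module attributes:
```python
import torch

from vllm_ascend.utils import npu_wait_tensor


def q_up_proj_after_rope_index(q_c: torch.Tensor,
                               w_uq: torch.Tensor,
                               cos_cached: torch.Tensor,
                               sin_cached: torch.Tensor,
                               positions: torch.Tensor,
                               enabled: bool = True) -> torch.Tensor:
    # The two gathers below are the IndexByTensor ops; they stay on the
    # main stream for now.
    cos = cos_cached[positions]
    sin = sin_cached[positions]
    # Attach ordering edges: consumers of q_c are now scheduled only after
    # cos and sin have been produced.
    q_c = npu_wait_tensor(q_c, cos, enabled=enabled)
    q_c = npu_wait_tensor(q_c, sin, enabled=enabled)
    # The `matmul W_UQ` of the diagram; the graph optimizer can no longer
    # hoist it above the cos/sin computation.
    return torch.matmul(q_c, w_uq)
```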
### Does this PR introduce _any_ user-facing change?
Yes. The feature is controlled by
`torchair_graph_config.enable_multistream_mla`, which defaults to `False`.
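For reference, a minimal offline-inference sketch that enables the feature; the model name is illustrative, and it assumes the standard `additional_config` engine argument that vllm-ascend consumes:
```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative; any MLA model
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            # Added by this PR; defaults to False.
            "enable_multistream_mla": True,
        },
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
```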
### How was this patch tested?
Tested on a 1x16 910 node, with a tailored 2-layer DSKv2 model.
Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
```diff
@@ -39,6 +39,7 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable torchair graph mode |
+| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream |
 | `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
 | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
```
```diff
@@ -59,6 +59,7 @@ def test_run_with_ascend_config():
             "graph_batch_sizes": [1, 2, 4, 8],
             "graph_batch_sizes_init": False,
             "enable_multistream_moe": True,
+            "enable_multistream_mla": True,
         },
         "ascend_scheduler_config": {
             "enabled": True,
@@ -79,6 +80,7 @@ def test_run_with_ascend_config():
            1, 2, 4, 8
        ]
        assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
+       assert ascend_config.torchair_graph_config.enable_multistream_mla
        assert ascend_config.torchair_graph_config.enable_multistream_moe
        assert ascend_config.ascend_scheduler_config.enabled
        assert ascend_config.ascend_scheduler_config.enable_chunked_prefill
```
```diff
@@ -54,6 +54,8 @@ class TorchairGraphConfig:
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
+        self.enable_multistream_mla = torchair_graph_config.get(
+            "enable_multistream_mla", False)
         self.enable_multistream_moe = torchair_graph_config.get(
             "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
```
```diff
@@ -19,6 +19,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
```
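The two helpers imported here are not shown in this diff. A plausible sketch of their shape, assuming they are thin optional-enable wrappers over torchair's scope primitives that degrade to no-ops when the feature is off:
```python
from contextlib import nullcontext

import torch
from torchair.scope import npu_stream_switch as _npu_stream_switch
from torchair.scope import npu_wait_tensor as _npu_wait_tensor


def npu_stream_switch(tag: str, priority: int, enabled: bool = True):
    # Trace the ops inside this context onto the stream named `tag`;
    # when disabled, fall back to a do-nothing context manager.
    return _npu_stream_switch(tag, priority) if enabled else nullcontext()


def npu_wait_tensor(self: torch.Tensor,
                    dependency: torch.Tensor,
                    enabled: bool = True) -> torch.Tensor:
    # Return `self` carrying an ordering edge on `dependency`, so consumers
    # of the result cannot be scheduled before `dependency` is ready.
    return _npu_wait_tensor(self, dependency) if enabled else self
```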
```diff
@@ -481,6 +482,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
+
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -664,17 +668,20 @@ class AscendMLAImpl(MLAAttentionImpl):
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-            kv,
-            self.kv_a_layernorm.weight,
-            cos,
-            sin,
-            slots.to(torch.int64),
-            kv_cache[1],
-            kv_cache[0],
-            epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode=cache_mode,
-        )
+        with npu_stream_switch("mla_secondary",
+                               0,
+                               enabled=self.enable_multistream_mla):
+            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+                kv,
+                self.kv_a_layernorm.weight,
+                cos,
+                sin,
+                slots.to(torch.int64),
+                kv_cache[1],
+                kv_cache[0],
+                epsilon=self.kv_a_layernorm.variance_epsilon,
+                cache_mode=cache_mode,
+            )
         return k_pe, k_nope
 
     def exec_kv_prefill(
```
```diff
@@ -867,23 +874,38 @@ class AscendMLAImpl(MLAAttentionImpl):
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
                 seq_len = self.rotary_emb.max_position_embeddings
                 cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
+                    dtype=decode_hs_or_q_c.dtype)
                 sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
+                    dtype=decode_hs_or_q_c.dtype)
                 cos = cos[attn_metadata.decode.input_positions]
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
-
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                # Without explicitly controlling the order, IndexByTensor operations
+                # would be placed after `matmul W_KV_T` hindering the overlapping of
+                # KvRmsNormRopeCache and SingleRope.
+                npu_wait_tensor(decode_hs_or_q_c,
+                                cos,
+                                enabled=self.enable_multistream_mla)
+                npu_wait_tensor(decode_hs_or_q_c,
+                                sin,
+                                enabled=self.enable_multistream_mla)
+            decode_ql_nope, decode_q_pe = \
+                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+            if self.running_in_graph:
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                     attn_metadata.slot_mapping)
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=self.enable_multistream_mla):
+                    npu_wait_tensor(decode_q_pe,
+                                    decode_k_pe,
+                                    enabled=self.enable_multistream_mla)
+                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
```
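The hunk above also shows the generic consumer-side pattern: work offloaded to the secondary stream waits on its main-stream input right inside the stream-switch scope (`npu_wait_tensor(decode_q_pe, decode_k_pe)` before `rope_single`). A toy distillation of that pattern, with plain matmuls standing in for the fused kernels; all names are illustrative:
```python
import torch

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def overlapped_pair(x: torch.Tensor,
                    w_main: torch.Tensor,
                    w_side: torch.Tensor,
                    enabled: bool = True) -> torch.Tensor:
    # Main stream: the heavyweight matmul.
    main = torch.matmul(x, w_main)
    # Secondary stream: lighter work that overlaps with the matmul above
    # but eventually needs its result.
    with npu_stream_switch("mla_secondary", 0, enabled=enabled):
        side = torch.matmul(x, w_side)
        # Cross-stream sync: do not combine before `main` is produced.
        side = npu_wait_tensor(side, main, enabled=enabled)
        out = side + main
    return out
```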
```diff
@@ -71,7 +71,8 @@ from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
+                               npu_wait_tensor)
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
@@ -496,6 +497,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
 
     def forward(
             self,
@@ -505,7 +508,14 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            use_multistream_mla = (self.enable_multistream_mla
+                                   and attn_metadata is not None
+                                   and attn_metadata.num_decodes > 0)
+            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
+            with npu_stream_switch("mla_secondary",
+                                   0,
+                                   enabled=use_multistream_mla):
+                hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
```