[Refactor][WIP] Refactor mla_v1 by moving all MLA preprocessing ops into mla_v1 attention impl (#2465)

### What this PR does / why we need it?
In order to support fused kernels, multi-stream, communication
optimization etc, it's better to aggregate all opreations in Attention
layer togather. This PR tries to refactor mla_v1 by moving all MLA
preprocessing ops into mla_v1 attention impl.
Note that new mla_v1 doesn't take torchair into consideration. So this
PR can only be merged after torchair related mla_v1 is isolated into a
new file.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?

### Features Test

<img width="506" height="141" alt="image"
src="https://github.com/user-attachments/assets/f1ab2906-a1ac-4450-8433-94811cd89466"
/>

### Performance After Refact
<img width="648" height="486" alt="image"
src="https://github.com/user-attachments/assets/e33e038c-c5d9-4ba7-a8e9-1ac22f9833eb"
/>

### Performance Before Refact
<img width="618" height="494" alt="image"
src="https://github.com/user-attachments/assets/83861dc2-dc51-4af3-9310-90ab10c43bb1"
/>


- vLLM version: v0.10.1.1
- vLLM main:
e03940762b

---------

Signed-off-by: lwq <liwenquan5@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
Co-authored-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
LeeWenquan
2025-08-28 10:35:57 +08:00
committed by GitHub
parent 320edde2df
commit c8d1df3a3f
5 changed files with 410 additions and 345 deletions

View File

@@ -31,6 +31,7 @@ The following table lists the additional configuration options available in vLLM
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. | | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. | | `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. | | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. | | `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |

View File

@@ -210,6 +210,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
with patch("vllm_ascend.attention.mla_v1.get_ascend_config", with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config): return_value=ascend_config):
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device) builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
builder.decode_threshold = 1
input_batch = MagicMock() input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3] input_batch.req_ids = [0, 1, 2, 3]
@@ -303,18 +304,16 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(self.impl.num_queries_per_kv, 32) self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.tp_size, 2) self.assertEqual(self.impl.tp_size, 2)
def test_v_up_proj_and_o_proj(self): def test_v_up_proj(self):
batch_size = 4 batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads, x = torch.randn(batch_size, self.impl.num_heads,
self.impl.kv_lora_rank) self.impl.kv_lora_rank)
self.impl.o_proj.return_value = (torch.randn(
batch_size, self.impl.num_heads * self.impl.v_head_dim), )
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None: if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
self.impl.W_UV = torch.randn(self.impl.num_heads, self.impl.W_UV = torch.randn(self.impl.num_heads,
self.impl.kv_lora_rank, self.impl.kv_lora_rank,
self.impl.v_head_dim) self.impl.v_head_dim)
result = self.impl._v_up_proj_and_o_proj(x) result = self.impl._v_up_proj(x)
self.assertEqual(result.shape[0], batch_size) self.assertEqual(result.shape[0], batch_size)
self.assertEqual(result.shape[1], self.assertEqual(result.shape[1],
@@ -371,8 +370,11 @@ class TestAscendMLAImpl(TestBase):
metadata.prefill = None metadata.prefill = None
prefix_out = torch.randn(2, 16, 128) prefix_out = torch.randn(2, 16, 128)
prefix_lse = torch.randn(2, 16, 8) prefix_lse = torch.randn(2, 16, 8)
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32, q_pe = query[..., self.impl.qk_nope_head_dim:]
metadata, prefix_out, q_nope = query[..., :self.impl.qk_nope_head_dim]
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
32, metadata, prefix_out,
prefix_lse) prefix_lse)
self.assertTrue(torch.equal(prefix_out, out)) self.assertTrue(torch.equal(prefix_out, out))
@@ -386,6 +388,8 @@ class TestAscendMLAImpl(TestBase):
latent_kv_dim = self.impl.kv_lora_rank latent_kv_dim = self.impl.kv_lora_rank
num_blocks, block_size = 100, 20 num_blocks, block_size = 100, 20
query = torch.randn(S, N, D) query = torch.randn(S, N, D)
q_nope = query[..., :self.impl.qk_nope_head_dim]
q_pe = query[..., self.impl.qk_nope_head_dim:]
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim) kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
kv_cache_1 = torch.randn(num_blocks, block_size, N, D) kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
kv_cache = [kv_cache_0, kv_cache_1] kv_cache = [kv_cache_0, kv_cache_1]
@@ -406,9 +410,11 @@ class TestAscendMLAImpl(TestBase):
meta = MagicMock() meta = MagicMock()
meta.prefill = prefill_meta meta.prefill = prefill_meta
self.impl.prefill_mask = torch.triu(
torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype), 1)
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32, out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
meta, prefix_out, 32, meta, prefix_out,
prefix_lse) prefix_lse)
mock_load.assert_called_once() mock_load.assert_called_once()
@@ -417,67 +423,36 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(out.shape, prefix_out.shape) self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape) self.assertEqual(lse.shape, prefix_lse.shape)
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj_and_o_proj") @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
@patch("torch_npu._npu_paged_attention_mla") @patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_decode_without_graph(self, mock_page_attention_mla, def test_forward_decode_without_graph(self,
mock_npu_fused_infer_attention_score,
mock_up_proj): mock_up_proj):
num_tokens = 100 num_tokens = 100
num_blocks = 256
block_size = 4 block_size = 4
q_nope = torch.randn(num_tokens, self.impl.num_heads, q_nope = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim) self.impl.qk_nope_head_dim)
q_pe = torch.randn(num_tokens, self.impl.num_heads, q_pe = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim) self.impl.qk_rope_head_dim)
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size, k_nope = torch.randn(num_tokens, self.impl.num_heads,
self.impl.num_heads, self.impl.qk_nope_head_dim)
self.impl.kv_lora_rank) k_pe = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
metadata = MagicMock() metadata = MagicMock()
metadata.decode = MagicMock() metadata.decode = MagicMock()
metadata.decode.block_table = MagicMock() metadata.decode.block_table = MagicMock()
metadata.decode.seq_lens = 10 metadata.decode.seq_lens = 10
mock_page_attention_mla.return_value = torch.randn( mock_npu_fused_infer_attention_score.return_value = [
num_tokens, self.impl.num_heads, self.impl.kv_lora_rank) torch.randn(num_tokens, self.impl.num_heads,
self.impl.kv_lora_rank), None
]
mock_up_proj.return_value = torch.randn(num_tokens, mock_up_proj.return_value = torch.randn(num_tokens,
self.impl.num_heads, self.impl.num_heads,
self.impl.v_head_dim) self.impl.v_head_dim)
result = self.impl._forward_decode(q_nope, q_pe, None, None, result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
kv_c_and_k_pe_cache, metadata) block_size, metadata)
self.assertEqual(result.shape[0], num_tokens) self.assertEqual(result.shape[0], num_tokens)
self.assertEqual(result.shape[1], self.impl.num_heads) self.assertEqual(result.shape[1], self.impl.num_heads)
self.assertEqual(result.shape[2], self.impl.v_head_dim) self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once() mock_up_proj.assert_called_once()
mock_page_attention_mla.assert_called_once() mock_npu_fused_infer_attention_score.assert_called_once()
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
@patch("torch_npu._npu_reshape_and_cache")
def test_forward_without_graph(self, _, mock_forward_prefill):
num_tokens = 100
num_blocks = 256
block_size = 4
rotary_emb_return_value = (torch.randn(num_tokens, 16,
self.impl.kv_lora_rank),
torch.randn(0, 1, self.impl.kv_lora_rank))
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
1, num_blocks, 128)
hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
self.impl.kv_lora_rank)
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.kv_lora_rank),
torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.qk_rope_head_dim))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.v_head_dim)
metadata = MagicMock()
metadata.num_decodes = 0
metadata.num_prefills = num_tokens
mock_forward_prefill.return_value = torch.randn(
0, self.impl.num_heads * self.impl.v_head_dim)
result = self.impl.forward(None, hidden_states_or_q_c,
hidden_states_or_kv_c_normed, k_pe,
kv_cache, metadata, output, False)
self.assertEqual(result.shape[0], num_tokens)

View File

@@ -50,6 +50,7 @@ class AscendConfig:
self.enable_shared_expert_dp = additional_config.get( self.enable_shared_expert_dp = additional_config.get(
"enable_shared_expert_dp", False "enable_shared_expert_dp", False
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
self.enable_prefetch = additional_config.get("enable_prefetch", False)
class TorchairGraphConfig: class TorchairGraphConfig:

View File

@@ -1,19 +1,18 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
import torch import torch
import torch.nn as nn
import torch_npu import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, from torch import nn
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl) MLAAttentionImpl)
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down from vllm.utils import cdiv, round_down
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
@@ -22,6 +21,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -184,6 +184,9 @@ class AscendMLAMetadataBuilder:
self.max_blocks = (vllm_config.model_config.max_model_len + self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size self.block_size - 1) // self.block_size
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.decode_threshold = 1
if self.chunked_prefill_enabled: if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min( self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least # Max sure there is enough for 8 full length request or at least
@@ -224,7 +227,7 @@ class AscendMLAMetadataBuilder:
for i, req_id in enumerate(input_batch.req_ids): for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_tokens = scheduler_output.num_scheduled_tokens[req_id]
if num_tokens == 1: if num_tokens <= self.decode_threshold:
decodes.append(i) decodes.append(i)
else: else:
prefills.append(i) prefills.append(i)
@@ -270,9 +273,8 @@ class AscendMLAMetadataBuilder:
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
# TODO(xyx): remove the if condition after mla supports torch mode speculative decoding # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
decode_threshold = 1
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold) split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
@@ -312,8 +314,8 @@ class AscendMLAMetadataBuilder:
if num_prefills > 0: if num_prefills > 0:
reqs_start = num_decodes # prefill_start reqs_start = num_decodes # prefill_start
tokens_start = num_decode_tokens tokens_start = num_decode_tokens
max_query_len = query_lens[tokens_start:].max().item() max_query_len = query_lens[reqs_start:].max().item()
max_seq_lens = seq_lens[tokens_start:].max().item() max_seq_lens = seq_lens[reqs_start:].max().item()
prefill_query_start_loc = query_start_loc[ prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start] reqs_start:] - query_start_loc[reqs_start]
@@ -359,9 +361,9 @@ class AscendMLAMetadataBuilder:
1).unsqueeze(2) 1).unsqueeze(2)
prefill_metadata = AscendMLAPrefillMetadata( prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask, attn_mask=common_attn_metadata.attn_mask,
query_lens=query_lens[tokens_start:], query_lens=query_lens[reqs_start:],
seq_lens=seq_lens, seq_lens=seq_lens,
context_lens=seq_lens[tokens_start:], context_lens=seq_lens[reqs_start:],
input_positions=prefill_input_positions, input_positions=prefill_input_positions,
block_table=block_table[reqs_start:, ...], block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len, max_query_len=max_query_len,
@@ -416,6 +418,21 @@ class AscendMLAMetadataBuilder:
) )
class DecodeMLAPreprocessResult(NamedTuple):
ql_nope: Optional[torch.Tensor] = None
q_pe: Optional[torch.Tensor] = None
k_nope: Optional[torch.Tensor] = None
k_pe: Optional[torch.Tensor] = None
class PrefillMLAPreprocessResult(NamedTuple):
q_nope: Optional[torch.Tensor] = None
q_pe: Optional[torch.Tensor] = None
k_nope: Optional[torch.Tensor] = None
k_pe: Optional[torch.Tensor] = None
value: Optional[torch.Tensor] = None
class AscendMLAImpl(MLAAttentionImpl): class AscendMLAImpl(MLAAttentionImpl):
""" """
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
@@ -455,11 +472,18 @@ class AscendMLAImpl(MLAAttentionImpl):
self.o_proj = kwargs['o_proj'] self.o_proj = kwargs['o_proj']
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_proj = kwargs.get('q_a_proj', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.enable_prefetch
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
self.prefill_mask = None
# Adapt torch air graph mode with spec decoding. # Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config speculative_config = get_current_vllm_config().speculative_config
@@ -467,7 +491,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.spec_token_num = speculative_config.num_speculative_tokens self.spec_token_num = speculative_config.num_speculative_tokens
assert self.spec_token_num > 0 assert self.spec_token_num > 0
def _v_up_proj_and_o_proj(self, x): def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V) # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
@@ -546,7 +570,8 @@ class AscendMLAImpl(MLAAttentionImpl):
def _compute_prefill_context( def _compute_prefill_context(
self, self,
query: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor], kv_c_and_k_pe_cache: Tuple[torch.Tensor],
rope_dim: int, rope_dim: int,
attn_metadata: AscendMLAMetadata, attn_metadata: AscendMLAMetadata,
@@ -559,8 +584,6 @@ class AscendMLAImpl(MLAAttentionImpl):
return prefix_output, prefix_lse return prefix_output, prefix_lse
iters = len(prefill_metadata.chunked_context.seq_tot) iters = len(prefill_metadata.chunked_context.seq_tot)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
cache_kv_c = kv_c_and_k_pe_cache[0] cache_kv_c = kv_c_and_k_pe_cache[0]
@@ -575,19 +598,19 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_c_normed = torch.empty(toks, kv_c_normed = torch.empty(toks,
num_heads, num_heads,
latent_kv_dim, latent_kv_dim,
dtype=query.dtype, dtype=q_nope.dtype,
device=query.device) device=q_nope.device)
k_pe = torch.empty(toks, k_pe = torch.empty(toks,
num_heads, num_heads,
rope_dim, rope_dim,
dtype=query.dtype, dtype=q_nope.dtype,
device=query.device) device=q_nope.device)
torch_npu.atb.npu_paged_cache_load( torch_npu.atb.npu_paged_cache_load(
cache_kv_c, cache_kv_c,
cache_k_pe, cache_k_pe,
prefill_metadata.block_table, prefill_metadata.block_table,
seq_len2.to(query.device), seq_len2.to(q_nope.device),
seq_starts=prefill_metadata.chunked_context.starts[i], seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed, key=kv_c_normed,
value=k_pe, value=k_pe,
@@ -599,16 +622,13 @@ class AscendMLAImpl(MLAAttentionImpl):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1)
torch_npu.atb.npu_ring_mla( torch_npu.atb.npu_ring_mla(
q_nope=q_nope, q_nope=q_nope,
q_rope=q_pe, q_rope=q_pe,
k_nope=k_nope, k_nope=k_nope,
k_rope=k_pe, k_rope=k_pe,
value=v, value=v,
mask=mask, mask=self.prefill_mask,
seqlen=seq_len, seqlen=seq_len,
head_num=self.num_heads, head_num=self.num_heads,
kv_head_num=self.num_heads, kv_head_num=self.num_heads,
@@ -625,33 +645,74 @@ class AscendMLAImpl(MLAAttentionImpl):
def _forward_prefill( def _forward_prefill(
self, self,
query: torch.Tensor, q_nope: torch.Tensor,
kv_c_normed: torch.Tensor, q_pe: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
value: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor], kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata, attn_metadata: AscendMLAMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
assert len(kv_c_and_k_pe_cache) > 1 assert len(kv_c_and_k_pe_cache) > 1
num_tokens = q_nope.size(0)
num_tokens = query.size(0)
attn_output = torch.empty(num_tokens, attn_output = torch.empty(num_tokens,
self.num_heads, self.num_heads,
self.v_head_dim, self.v_head_dim,
dtype=query.dtype, dtype=q_nope.dtype,
device=query.device) device=q_nope.device)
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( query = torch.cat((q_nope, q_pe), dim=-1)
[self.qk_nope_head_dim, self.v_head_dim], dim=-1) key = torch.cat((k_nope, k_pe), dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) torch_npu._npu_flash_attention(
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache query=query,
ascend_config = get_ascend_config() key=key,
value=value,
if attn_metadata.attn_state in [ mask=attn_metadata.attn_mask,
AscendAttentionState.ChunkedPrefill, seq_len=attn_metadata.prefill.context_lens,
AscendAttentionState.SpecDecoding, scale_value=self.scale,
AscendAttentionState.PrefillCacheHit num_heads=self.num_heads,
] and not ascend_config.chunked_prefill_for_mla: num_kv_heads=self.num_heads,
out=attn_output)
elif self.chunked_prefill_for_mla:
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=q_nope.device)
if self.prefill_mask is None:
self.prefill_mask = torch.triu(
torch.ones(512,
512,
device=q_nope.device,
dtype=q_nope.dtype),
1) # 512: mask only support 512
if attn_metadata.num_prefills > 1:
self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
attn_metadata.num_prefills, 1, 1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=self.prefill_mask,
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
else:
query = torch.cat((q_nope, q_pe), dim=-1)
attn_output_torch = torch.empty(num_tokens, attn_output_torch = torch.empty(num_tokens,
self.num_heads * self.v_head_dim, self.num_heads * self.v_head_dim,
dtype=query.dtype, dtype=query.dtype,
@@ -673,240 +734,318 @@ class AscendMLAImpl(MLAAttentionImpl):
scale=self.scale, scale=self.scale,
alibi_slopes=None, alibi_slopes=None,
causal=True) causal=True)
elif attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding,
AscendAttentionState.PrefillCacheHit
]:
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=query.device)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1) # 512: mask only support 512
if attn_metadata.num_prefills > 1:
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=mask,
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
key = torch.cat((k_nope, k_pe), dim=-1)
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=attn_metadata.attn_mask,
seq_len=attn_metadata.prefill.context_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
out=attn_output)
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
else:
raise RuntimeError(
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
)
attn_output = attn_output.reshape( attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim]) [num_tokens, self.num_heads * self.v_head_dim])
if attn_metadata.attn_state in [ if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill, AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding, AscendAttentionState.SpecDecoding,
AscendAttentionState.PrefillCacheHit AscendAttentionState.PrefillCacheHit
] and not ascend_config.chunked_prefill_for_mla: ] and not self.chunked_prefill_for_mla:
attn_output = attn_output_torch attn_output = attn_output_torch
return attn_output return attn_output
def exec_kv_decode(
self,
kv_no_split: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
):
B = kv_no_split.shape[0]
N = self.num_kv_heads
S = 1
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return k_pe, k_nope
def exec_kv_prefill(
self,
kv_no_split: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
):
B = kv_no_split.shape[0]
N = self.num_kv_heads
S = 1
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
is_output_kv=True,
)
return k_pe, k_nope
def rope_single(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
B, N, D = x.shape
S = 1
x = x.view(B, N, S, D)
x = torch_npu.npu_interleave_rope(x, cos, sin)
return x.view(B, N, D)
def _forward_decode( def _forward_decode(
self, self,
q_nope: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor, q_pe: torch.Tensor,
k_nope: torch.Tensor, k_nope: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor], block_size: int,
attn_metadata: AscendMLAMetadata, attn_metadata: AscendMLAMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
decode_meta = attn_metadata.decode decode_meta = attn_metadata.decode
assert decode_meta is not None assert decode_meta is not None
num_tokens = q_nope.size(0) num_tokens = q_nope.size(0)
# The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will # shape of knope/k_pe for npu graph mode should be:
# be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
# public available actual_seq_lengths = None
assert len(kv_c_and_k_pe_cache) > 1 if self.enable_kv_nz:
if envs_ascend.VLLM_ASCEND_MLA_PA: k_nope = k_nope.view(-1, self.num_kv_heads,
attn_output = torch_npu.atb.npu_multi_head_latent_attention( self.kv_lora_rank // 16, block_size, 16)
q_nope, q_pe, kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1], k_pe = k_pe.view(-1, self.num_kv_heads,
attn_metadata.decode.block_table, self.qk_rope_head_dim // 16, block_size, 16)
attn_metadata.decode.seq_lens, self.num_heads, self.scale, input_layout = "BSND"
self.num_kv_heads)
else: else:
q = torch.cat([q_nope, q_pe], dim=-1) k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
attn_output = torch.empty( self.kv_lora_rank)
[num_tokens, self.num_heads, self.kv_lora_rank], k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
dtype=q.dtype, self.qk_rope_head_dim)
device=q.device) input_layout = "BNSD"
k_cache = torch.cat(
[kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1) if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
torch_npu._npu_paged_attention_mla( assert num_tokens % self.spec_token_num == 0
query=q, input_layout = "TND"
key_cache=k_cache, # [bs * q_seq_len, num_heads_per_rank, dim]
num_kv_heads=self.num_kv_heads, q_nope = q_nope.view(num_tokens, self.num_heads, -1)
num_heads=self.num_heads, q_pe = q_pe.view(num_tokens, self.num_heads, -1)
scale_value=self.scale, sparse_mode = 3
block_table=attn_metadata.decode.block_table, # type:ignore spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
context_lens=attn_metadata.decode.seq_lens, # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q
mla_vheadsize=self.kv_lora_rank, else:
out=attn_output) if self.enable_kv_nz:
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
else:
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=decode_meta.block_table,
block_size=block_size,
actual_seq_lengths_kv=decode_meta.seq_lens_list,
actual_seq_lengths=actual_seq_lengths)
current_ms_metadata = get_multistream_comm_context() current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None: if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output) return self._v_up_proj(attn_output)
else: else:
current_ms_metadata.before_comm_event.record() current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream): with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait() current_ms_metadata.before_comm_event.wait()
return self._v_up_proj_and_o_proj(attn_output) return self._v_up_proj(attn_output)
def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
need_gather_q_kv):
# MLA Preprocess:
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
# 3. If need_gather_q_kv, perform all_gather.
# 4. Preprocess decode tokens, write kv cache and get:
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
# 5. Preprocess prefill tokens, write kv cache and get:
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
if self.q_a_proj is not None:
npu_prefetch(self.q_a_proj.weight,
hidden_states,
enabled=self.enable_prefetch)
ckq = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(ckq)
else:
q_c = hidden_states
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
# Process for shared_expert_dp
if need_gather_q_kv:
q_c = get_tp_group().all_gather(q_c, 0)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
decode_preprocess_res = None
prefill_preprocess_res = None
# Preprocess for decode tokens
if has_decode:
decode_q_c = q_c[:num_decode_tokens]
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_q_c)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
decode_kv_no_split = kv_no_split[:num_decode_tokens]
decode_k_pe, decode_k_nope = self.exec_kv_decode(
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
decode_preprocess_res = DecodeMLAPreprocessResult(
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
# Preprocess for prefill tokens
if has_prefill:
prefill_kv_no_split = kv_no_split[
num_decode_tokens:num_actual_tokens]
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
prefill_q = self.q_proj(prefill_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
prefill_slots = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
self.num_kv_heads, -1)
prefill_k_nope, prefill_value = self.kv_b_proj(
prefill_k_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
prefill_k_pe = prefill_k_pe.expand(
(*prefill_k_nope.shape[:-1], -1))
prefill_preprocess_res = PrefillMLAPreprocessResult(
prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
prefill_value)
return decode_preprocess_res, prefill_preprocess_res
def forward( def forward(
self, self,
layer: AttentionLayer, hidden_states: torch.Tensor, # query in unified attn
hidden_states_or_q_c: torch.Tensor, # query in unified attn
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: Tuple[torch.Tensor], kv_cache: Tuple[torch.Tensor],
attn_metadata: M, attn_metadata: M,
need_gather_q_kv: bool = False,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
ckq: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if attn_metadata is None: if attn_metadata is None:
# Profiling run. # Profiling run.
return output return output
num_actual_toks = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
if k_pe is None:
kv_c, k_pe = self.kv_a_proj_with_mqa(
hidden_states_or_kv_c_normed)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
else:
kv_c_normed = hidden_states_or_kv_c_normed
assert attn_metadata.num_decodes is not None and \ assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \ attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None attn_metadata.num_decode_tokens is not None
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs # Inputs and outputs may be padded for CUDA graphs
output_padded = output output_padded = output
output = output[:num_actual_toks, ...] output = output[:num_actual_tokens, ...]
kv_c_normed = kv_c_normed[:num_actual_toks, ...] o_proj_input_shape = (num_actual_tokens,
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
k_pe = k_pe[:num_actual_toks, ...]
k_pe = k_pe.unsqueeze(1)
decode_k_pe = k_pe[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:]
if has_decode:
decode_k_nope = None
assert attn_metadata.decode is not None
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions,
decode_q_pe.contiguous(),
decode_k_pe,
max_seq_len=attn_metadata.decode.max_seq_lens)
if has_prefill:
assert attn_metadata.prefill is not None
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
attn_metadata.prefill.input_positions,
prefill_q_pe.contiguous(),
prefill_k_pe,
max_seq_len=attn_metadata.prefill.max_seq_lens)
assert len(
kv_cache
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
kv_c_normed = kv_c_normed.view(
[num_actual_toks, self.num_kv_heads, -1])
torch_npu._npu_reshape_and_cache(
key=kv_c_normed,
value=k_pe,
key_cache=kv_cache[0],
value_cache=kv_cache[1],
slot_indices=attn_metadata.slot_mapping)
o_proj_input_shape = (num_actual_toks,
self.num_heads * self.v_head_dim) self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape, o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states_or_q_c.dtype, dtype=hidden_states.dtype,
device=hidden_states_or_q_c.device) device=hidden_states.device)
if has_prefill:
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
# TODO: use an elegant way to overlap
output_prefill = self._forward_prefill(prefill_q,
prefill_k_c_normed,
prefill_k_pe, kv_cache,
attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
o_proj_input[num_decode_tokens:] = output_prefill
else:
o_proj_input[num_decode_tokens:] = output_prefill
if has_decode: # MLA Preprocess
output_decode = self._forward_decode(decode_ql_nope, decode_q_pe, decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
decode_k_nope, decode_k_pe, hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
kv_cache, attn_metadata)
if decode_preprocess_res is not None:
# MLA Preprocess for decoding
output_decode = self._forward_decode(decode_preprocess_res.ql_nope,
decode_preprocess_res.q_pe,
decode_preprocess_res.k_nope,
decode_preprocess_res.k_pe,
kv_cache[0].shape[1],
attn_metadata)
current_ms_metadata = get_multistream_comm_context() current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None: if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream): with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[:num_decode_tokens] = output_decode o_proj_input[:num_decode_tokens] = output_decode
current_ms_metadata.after_comm_event.record()
else: else:
o_proj_input[:num_decode_tokens] = output_decode o_proj_input[:num_decode_tokens] = output_decode
if prefill_preprocess_res is not None:
# FIX: aicore move should be also placed on the comm stream in dbo,
# otherwise it may affect the accuracy
# TODO: use an elegant way to overlap
output_prefill = self._forward_prefill(
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
prefill_preprocess_res.value, kv_cache, attn_metadata)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is not None:
with torch.npu.stream(current_ms_metadata.comm_stream):
o_proj_input[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[num_decode_tokens:] = output_prefill
# O proj
current_ms_metadata = get_multistream_comm_context() current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
if current_ms_metadata is None: if current_ms_metadata is None:
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj( output[...] = self.o_proj(
o_proj_input, o_proj_input,
is_prefill=True, is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0] is_force_scatter=self.enable_shared_expert_dp)[0]
else: else:
with torch.npu.stream(current_ms_metadata.comm_stream): with torch.npu.stream(current_ms_metadata.comm_stream):
npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj( output[...] = self.o_proj(
o_proj_input, o_proj_input,
is_prefill=True, is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0] is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record() current_ms_metadata.after_comm_event.record()
del o_proj_input del o_proj_input

View File

@@ -37,7 +37,6 @@ from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim, get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter) tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_dp_group, get_ep_group from vllm.distributed.parallel_state import get_dp_group, get_ep_group
@@ -73,7 +72,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor, npu_prefetch from vllm_ascend.utils import dispose_tensor
class CustomDeepseekV2SiluAndMul(SiluAndMul): class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -471,9 +470,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.debug_layer_idx = int(self.prefix.split(".")[-2]) self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_mla = \
ascend_config.torchair_graph_config.enable_multistream_mla
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
@@ -515,8 +511,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0 and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.torchair_graph_config.enable_multistream_moe and self.enable_shared_expert_dp):
or self.enable_shared_expert_dp)):
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce( self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim, self.num_heads * self.v_head_dim,
self.hidden_size, self.hidden_size,
@@ -568,6 +563,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
qk_head_dim=self.qk_head_dim, qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim, v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb, rotary_emb=self.rotary_emb,
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm, kv_a_layernorm=self.kv_a_layernorm,
@@ -582,55 +580,29 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
kv_cache: Optional[torch.Tensor] = None, kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
forward_context = get_forward_context() forward_context = get_forward_context()
enable_multistream_mla = (self.enable_multistream_mla if kv_cache is None:
and attn_metadata is not None kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
and not forward_context.with_prefill num_tokens = hidden_states.shape[0]
and attn_metadata.num_decodes > 0) need_gather_q_kv = False
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla} if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
if self.q_lora_rank is not None: # Simulate all gather to calculate output shape
npu_prefetch(self.q_a_proj.weight, num_tokens = num_tokens * self.tp_size
hidden_states, need_gather_q_kv = True
enabled=enable_multistream_mla) if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
forward_kwargs['ckq'] = ckq
else:
hidden_states_or_q_c = hidden_states
if self.torchair_graph_enabled:
output_shape = hidden_states.shape output_shape = hidden_states.shape
output = torch.empty(output_shape,
dtype=hidden_states_or_q_c.dtype,
device=hidden_states_or_q_c.device)
forward_kwargs['output'] = output
output = self.mla_attn.impl.forward(self.mla_attn,
hidden_states_or_q_c,
hidden_states, None, kv_cache,
attn_metadata,
**forward_kwargs)
output = output.view(-1, output_shape[-1])
return output
else: else:
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] rows = num_tokens // self.tp_size
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: if num_tokens % self.tp_size:
hidden_states_or_q_c = get_tp_group().all_gather( rows += 1
hidden_states_or_q_c, 0) output_shape = (rows, hidden_states.shape[1])
kv_no_split = get_tp_group().all_gather(kv_no_split, 0) output = torch.empty(output_shape,
dtype=hidden_states.dtype,
kv_c, k_pe = kv_no_split.split( device=hidden_states.device)
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) output = self.mla_attn.impl.forward(hidden_states, kv_cache,
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) forward_context.attn_metadata,
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: need_gather_q_kv, output)
output_shape = hidden_states.shape output = output.view(-1, output_shape[-1])
else: return output
num_tokens = hidden_states_or_q_c.shape[0]
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=output_shape)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -688,8 +660,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
and model_config.use_mla and self.tp_size > 1
else: else:
self.mlp = CustomDeepseekV2MLP( self.mlp = CustomDeepseekV2MLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
@@ -698,7 +668,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.mla_moe_communication = False
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -718,10 +687,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
replace_allreduce: bool = False, replace_allreduce: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
if attn_metadata is not None and attn_metadata.num_decodes > 0:
mla_moe_communication = self.mla_moe_communication and replace_allreduce
else:
mla_moe_communication = False
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@@ -733,9 +698,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# to save npu memory because they're no longer used. # to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states) dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual) dispose_tensor(previous_residual)
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
@@ -744,13 +706,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
) )
if mla_moe_communication and residual.shape[0] != hidden_states.shape[
0]:
chunk_hidden_states = torch.tensor_split(residual,
self.tp_size,
dim=0)
residual = chunk_hidden_states[self.tp_rank]
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
# We scale both hidden_states and residual before # We scale both hidden_states and residual before
@@ -778,9 +733,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
hidden_states, residual) hidden_states, residual)
if isinstance(self.mlp, CustomDeepseekV2MoE): if isinstance(self.mlp, CustomDeepseekV2MoE):
hidden_states = self.mlp(hidden_states, hidden_states = self.mlp(hidden_states, attn_metadata)
attn_metadata,
replace_allreduce=mla_moe_communication)
else: else:
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
@@ -793,10 +746,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# The scaling of DeepseekV2MOE output would be done in the forward # The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE # of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
if mla_moe_communication and self.layer_idx == self.layers - 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)
# for last layer of main model and mtp layer. # for last layer of main model and mtp layer.
if self.enable_shared_expert_dp and self.layer_idx >= ( if self.enable_shared_expert_dp and self.layer_idx >= (