[Refactor][WIP] Refactor mla_v1 by moving all MLA preprocessing ops into mla_v1 attention impl (#2465)
### What this PR does / why we need it?
In order to support fused kernels, multi-stream, communication
optimization etc, it's better to aggregate all opreations in Attention
layer togather. This PR tries to refactor mla_v1 by moving all MLA
preprocessing ops into mla_v1 attention impl.
Note that new mla_v1 doesn't take torchair into consideration. So this
PR can only be merged after torchair related mla_v1 is isolated into a
new file.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
### Features Test
<img width="506" height="141" alt="image"
src="https://github.com/user-attachments/assets/f1ab2906-a1ac-4450-8433-94811cd89466"
/>
### Performance After Refact
<img width="648" height="486" alt="image"
src="https://github.com/user-attachments/assets/e33e038c-c5d9-4ba7-a8e9-1ac22f9833eb"
/>
### Performance Before Refact
<img width="618" height="494" alt="image"
src="https://github.com/user-attachments/assets/83861dc2-dc51-4af3-9310-90ab10c43bb1"
/>
- vLLM version: v0.10.1.1
- vLLM main:
e03940762b
---------
Signed-off-by: lwq <liwenquan5@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
Co-authored-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -31,6 +31,7 @@ The following table lists the additional configuration options available in vLLM
|
|||||||
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
|
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
|
||||||
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
|
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
|
||||||
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
|
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
|
||||||
|
| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
|
||||||
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
|
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
|
||||||
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
|
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
|
||||||
|
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||||
return_value=ascend_config):
|
return_value=ascend_config):
|
||||||
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
|
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
|
||||||
|
builder.decode_threshold = 1
|
||||||
|
|
||||||
input_batch = MagicMock()
|
input_batch = MagicMock()
|
||||||
input_batch.req_ids = [0, 1, 2, 3]
|
input_batch.req_ids = [0, 1, 2, 3]
|
||||||
@@ -303,18 +304,16 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
self.assertEqual(self.impl.num_queries_per_kv, 32)
|
self.assertEqual(self.impl.num_queries_per_kv, 32)
|
||||||
self.assertEqual(self.impl.tp_size, 2)
|
self.assertEqual(self.impl.tp_size, 2)
|
||||||
|
|
||||||
def test_v_up_proj_and_o_proj(self):
|
def test_v_up_proj(self):
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
x = torch.randn(batch_size, self.impl.num_heads,
|
x = torch.randn(batch_size, self.impl.num_heads,
|
||||||
self.impl.kv_lora_rank)
|
self.impl.kv_lora_rank)
|
||||||
|
|
||||||
self.impl.o_proj.return_value = (torch.randn(
|
|
||||||
batch_size, self.impl.num_heads * self.impl.v_head_dim), )
|
|
||||||
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
|
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
|
||||||
self.impl.W_UV = torch.randn(self.impl.num_heads,
|
self.impl.W_UV = torch.randn(self.impl.num_heads,
|
||||||
self.impl.kv_lora_rank,
|
self.impl.kv_lora_rank,
|
||||||
self.impl.v_head_dim)
|
self.impl.v_head_dim)
|
||||||
result = self.impl._v_up_proj_and_o_proj(x)
|
result = self.impl._v_up_proj(x)
|
||||||
|
|
||||||
self.assertEqual(result.shape[0], batch_size)
|
self.assertEqual(result.shape[0], batch_size)
|
||||||
self.assertEqual(result.shape[1],
|
self.assertEqual(result.shape[1],
|
||||||
@@ -371,8 +370,11 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
metadata.prefill = None
|
metadata.prefill = None
|
||||||
prefix_out = torch.randn(2, 16, 128)
|
prefix_out = torch.randn(2, 16, 128)
|
||||||
prefix_lse = torch.randn(2, 16, 8)
|
prefix_lse = torch.randn(2, 16, 8)
|
||||||
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
|
q_pe = query[..., self.impl.qk_nope_head_dim:]
|
||||||
metadata, prefix_out,
|
q_nope = query[..., :self.impl.qk_nope_head_dim]
|
||||||
|
|
||||||
|
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
|
||||||
|
32, metadata, prefix_out,
|
||||||
prefix_lse)
|
prefix_lse)
|
||||||
|
|
||||||
self.assertTrue(torch.equal(prefix_out, out))
|
self.assertTrue(torch.equal(prefix_out, out))
|
||||||
@@ -386,6 +388,8 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
latent_kv_dim = self.impl.kv_lora_rank
|
latent_kv_dim = self.impl.kv_lora_rank
|
||||||
num_blocks, block_size = 100, 20
|
num_blocks, block_size = 100, 20
|
||||||
query = torch.randn(S, N, D)
|
query = torch.randn(S, N, D)
|
||||||
|
q_nope = query[..., :self.impl.qk_nope_head_dim]
|
||||||
|
q_pe = query[..., self.impl.qk_nope_head_dim:]
|
||||||
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
|
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
|
||||||
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
|
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
|
||||||
kv_cache = [kv_cache_0, kv_cache_1]
|
kv_cache = [kv_cache_0, kv_cache_1]
|
||||||
@@ -406,9 +410,11 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
|
|
||||||
meta = MagicMock()
|
meta = MagicMock()
|
||||||
meta.prefill = prefill_meta
|
meta.prefill = prefill_meta
|
||||||
|
self.impl.prefill_mask = torch.triu(
|
||||||
|
torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype), 1)
|
||||||
|
|
||||||
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
|
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
|
||||||
meta, prefix_out,
|
32, meta, prefix_out,
|
||||||
prefix_lse)
|
prefix_lse)
|
||||||
|
|
||||||
mock_load.assert_called_once()
|
mock_load.assert_called_once()
|
||||||
@@ -417,67 +423,36 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
self.assertEqual(out.shape, prefix_out.shape)
|
self.assertEqual(out.shape, prefix_out.shape)
|
||||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj_and_o_proj")
|
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
|
||||||
@patch("torch_npu._npu_paged_attention_mla")
|
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||||
def test_forward_decode_without_graph(self, mock_page_attention_mla,
|
def test_forward_decode_without_graph(self,
|
||||||
|
mock_npu_fused_infer_attention_score,
|
||||||
mock_up_proj):
|
mock_up_proj):
|
||||||
num_tokens = 100
|
num_tokens = 100
|
||||||
num_blocks = 256
|
|
||||||
block_size = 4
|
block_size = 4
|
||||||
q_nope = torch.randn(num_tokens, self.impl.num_heads,
|
q_nope = torch.randn(num_tokens, self.impl.num_heads,
|
||||||
self.impl.qk_nope_head_dim)
|
self.impl.qk_nope_head_dim)
|
||||||
q_pe = torch.randn(num_tokens, self.impl.num_heads,
|
q_pe = torch.randn(num_tokens, self.impl.num_heads,
|
||||||
self.impl.qk_rope_head_dim)
|
self.impl.qk_rope_head_dim)
|
||||||
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
|
k_nope = torch.randn(num_tokens, self.impl.num_heads,
|
||||||
self.impl.num_heads,
|
self.impl.qk_nope_head_dim)
|
||||||
self.impl.kv_lora_rank)
|
k_pe = torch.randn(num_tokens, self.impl.num_heads,
|
||||||
|
self.impl.qk_rope_head_dim)
|
||||||
metadata = MagicMock()
|
metadata = MagicMock()
|
||||||
metadata.decode = MagicMock()
|
metadata.decode = MagicMock()
|
||||||
metadata.decode.block_table = MagicMock()
|
metadata.decode.block_table = MagicMock()
|
||||||
metadata.decode.seq_lens = 10
|
metadata.decode.seq_lens = 10
|
||||||
mock_page_attention_mla.return_value = torch.randn(
|
mock_npu_fused_infer_attention_score.return_value = [
|
||||||
num_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
|
torch.randn(num_tokens, self.impl.num_heads,
|
||||||
|
self.impl.kv_lora_rank), None
|
||||||
|
]
|
||||||
mock_up_proj.return_value = torch.randn(num_tokens,
|
mock_up_proj.return_value = torch.randn(num_tokens,
|
||||||
self.impl.num_heads,
|
self.impl.num_heads,
|
||||||
self.impl.v_head_dim)
|
self.impl.v_head_dim)
|
||||||
result = self.impl._forward_decode(q_nope, q_pe, None, None,
|
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
|
||||||
kv_c_and_k_pe_cache, metadata)
|
block_size, metadata)
|
||||||
self.assertEqual(result.shape[0], num_tokens)
|
self.assertEqual(result.shape[0], num_tokens)
|
||||||
self.assertEqual(result.shape[1], self.impl.num_heads)
|
self.assertEqual(result.shape[1], self.impl.num_heads)
|
||||||
self.assertEqual(result.shape[2], self.impl.v_head_dim)
|
self.assertEqual(result.shape[2], self.impl.v_head_dim)
|
||||||
mock_up_proj.assert_called_once()
|
mock_up_proj.assert_called_once()
|
||||||
mock_page_attention_mla.assert_called_once()
|
mock_npu_fused_infer_attention_score.assert_called_once()
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
|
|
||||||
@patch("torch_npu._npu_reshape_and_cache")
|
|
||||||
def test_forward_without_graph(self, _, mock_forward_prefill):
|
|
||||||
num_tokens = 100
|
|
||||||
num_blocks = 256
|
|
||||||
block_size = 4
|
|
||||||
rotary_emb_return_value = (torch.randn(num_tokens, 16,
|
|
||||||
self.impl.kv_lora_rank),
|
|
||||||
torch.randn(0, 1, self.impl.kv_lora_rank))
|
|
||||||
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
|
|
||||||
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
|
|
||||||
1, num_blocks, 128)
|
|
||||||
|
|
||||||
hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
|
|
||||||
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
|
|
||||||
self.impl.kv_lora_rank)
|
|
||||||
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
|
|
||||||
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
|
|
||||||
self.impl.kv_lora_rank),
|
|
||||||
torch.randn(num_blocks, block_size, self.impl.num_heads,
|
|
||||||
self.impl.qk_rope_head_dim))
|
|
||||||
output = torch.randn(num_tokens, self.impl.num_heads,
|
|
||||||
self.impl.v_head_dim)
|
|
||||||
|
|
||||||
metadata = MagicMock()
|
|
||||||
metadata.num_decodes = 0
|
|
||||||
metadata.num_prefills = num_tokens
|
|
||||||
mock_forward_prefill.return_value = torch.randn(
|
|
||||||
0, self.impl.num_heads * self.impl.v_head_dim)
|
|
||||||
result = self.impl.forward(None, hidden_states_or_q_c,
|
|
||||||
hidden_states_or_kv_c_normed, k_pe,
|
|
||||||
kv_cache, metadata, output, False)
|
|
||||||
self.assertEqual(result.shape[0], num_tokens)
|
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class AscendConfig:
|
|||||||
self.enable_shared_expert_dp = additional_config.get(
|
self.enable_shared_expert_dp = additional_config.get(
|
||||||
"enable_shared_expert_dp", False
|
"enable_shared_expert_dp", False
|
||||||
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
|
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
|
||||||
|
self.enable_prefetch = additional_config.get("enable_prefetch", False)
|
||||||
|
|
||||||
|
|
||||||
class TorchairGraphConfig:
|
class TorchairGraphConfig:
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
|
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
from torch import nn
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
MLAAttentionImpl)
|
MLAAttentionImpl)
|
||||||
from vllm.config import VllmConfig, get_current_vllm_config
|
from vllm.config import VllmConfig, get_current_vllm_config
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||||
from vllm.model_executor.layers.linear import (LinearBase,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.utils import cdiv, round_down
|
from vllm.utils import cdiv, round_down
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
@@ -22,6 +21,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
|||||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||||
|
from vllm_ascend.utils import npu_prefetch
|
||||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -184,6 +184,9 @@ class AscendMLAMetadataBuilder:
|
|||||||
self.max_blocks = (vllm_config.model_config.max_model_len +
|
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||||
self.block_size - 1) // self.block_size
|
self.block_size - 1) // self.block_size
|
||||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||||
|
|
||||||
|
self.decode_threshold = 1
|
||||||
|
|
||||||
if self.chunked_prefill_enabled:
|
if self.chunked_prefill_enabled:
|
||||||
self.chunked_prefill_workspace_size = min(
|
self.chunked_prefill_workspace_size = min(
|
||||||
# Max sure there is enough for 8 full length request or at least
|
# Max sure there is enough for 8 full length request or at least
|
||||||
@@ -224,7 +227,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
|
|
||||||
for i, req_id in enumerate(input_batch.req_ids):
|
for i, req_id in enumerate(input_batch.req_ids):
|
||||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||||
if num_tokens == 1:
|
if num_tokens <= self.decode_threshold:
|
||||||
decodes.append(i)
|
decodes.append(i)
|
||||||
else:
|
else:
|
||||||
prefills.append(i)
|
prefills.append(i)
|
||||||
@@ -270,9 +273,8 @@ class AscendMLAMetadataBuilder:
|
|||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
# TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
|
# TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
|
||||||
decode_threshold = 1
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
|
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||||
assert num_decodes + num_prefills == num_reqs
|
assert num_decodes + num_prefills == num_reqs
|
||||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
||||||
|
|
||||||
@@ -312,8 +314,8 @@ class AscendMLAMetadataBuilder:
|
|||||||
if num_prefills > 0:
|
if num_prefills > 0:
|
||||||
reqs_start = num_decodes # prefill_start
|
reqs_start = num_decodes # prefill_start
|
||||||
tokens_start = num_decode_tokens
|
tokens_start = num_decode_tokens
|
||||||
max_query_len = query_lens[tokens_start:].max().item()
|
max_query_len = query_lens[reqs_start:].max().item()
|
||||||
max_seq_lens = seq_lens[tokens_start:].max().item()
|
max_seq_lens = seq_lens[reqs_start:].max().item()
|
||||||
prefill_query_start_loc = query_start_loc[
|
prefill_query_start_loc = query_start_loc[
|
||||||
reqs_start:] - query_start_loc[reqs_start]
|
reqs_start:] - query_start_loc[reqs_start]
|
||||||
|
|
||||||
@@ -359,9 +361,9 @@ class AscendMLAMetadataBuilder:
|
|||||||
1).unsqueeze(2)
|
1).unsqueeze(2)
|
||||||
prefill_metadata = AscendMLAPrefillMetadata(
|
prefill_metadata = AscendMLAPrefillMetadata(
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
attn_mask=common_attn_metadata.attn_mask,
|
||||||
query_lens=query_lens[tokens_start:],
|
query_lens=query_lens[reqs_start:],
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
context_lens=seq_lens[tokens_start:],
|
context_lens=seq_lens[reqs_start:],
|
||||||
input_positions=prefill_input_positions,
|
input_positions=prefill_input_positions,
|
||||||
block_table=block_table[reqs_start:, ...],
|
block_table=block_table[reqs_start:, ...],
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
@@ -416,6 +418,21 @@ class AscendMLAMetadataBuilder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DecodeMLAPreprocessResult(NamedTuple):
|
||||||
|
ql_nope: Optional[torch.Tensor] = None
|
||||||
|
q_pe: Optional[torch.Tensor] = None
|
||||||
|
k_nope: Optional[torch.Tensor] = None
|
||||||
|
k_pe: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PrefillMLAPreprocessResult(NamedTuple):
|
||||||
|
q_nope: Optional[torch.Tensor] = None
|
||||||
|
q_pe: Optional[torch.Tensor] = None
|
||||||
|
k_nope: Optional[torch.Tensor] = None
|
||||||
|
k_pe: Optional[torch.Tensor] = None
|
||||||
|
value: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
class AscendMLAImpl(MLAAttentionImpl):
|
class AscendMLAImpl(MLAAttentionImpl):
|
||||||
"""
|
"""
|
||||||
NOTE: Please read the comment at the top of the file before trying to
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
@@ -455,11 +472,18 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.o_proj = kwargs['o_proj']
|
self.o_proj = kwargs['o_proj']
|
||||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||||
|
self.q_a_proj = kwargs.get('q_a_proj', None)
|
||||||
|
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||||
|
self.enable_prefetch = ascend_config.enable_prefetch
|
||||||
|
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||||
|
self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
|
||||||
|
|
||||||
|
self.prefill_mask = None
|
||||||
|
|
||||||
# Adapt torch air graph mode with spec decoding.
|
# Adapt torch air graph mode with spec decoding.
|
||||||
speculative_config = get_current_vllm_config().speculative_config
|
speculative_config = get_current_vllm_config().speculative_config
|
||||||
@@ -467,7 +491,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.spec_token_num = speculative_config.num_speculative_tokens
|
self.spec_token_num = speculative_config.num_speculative_tokens
|
||||||
assert self.spec_token_num > 0
|
assert self.spec_token_num > 0
|
||||||
|
|
||||||
def _v_up_proj_and_o_proj(self, x):
|
def _v_up_proj(self, x):
|
||||||
# Convert from (B, N, L) to (N, B, L)
|
# Convert from (B, N, L) to (N, B, L)
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||||
@@ -546,7 +570,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
def _compute_prefill_context(
|
def _compute_prefill_context(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
q_nope: torch.Tensor,
|
||||||
|
q_pe: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
||||||
rope_dim: int,
|
rope_dim: int,
|
||||||
attn_metadata: AscendMLAMetadata,
|
attn_metadata: AscendMLAMetadata,
|
||||||
@@ -559,8 +584,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
return prefix_output, prefix_lse
|
return prefix_output, prefix_lse
|
||||||
|
|
||||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||||
q_pe = query[..., self.qk_nope_head_dim:]
|
|
||||||
q_nope = query[..., :self.qk_nope_head_dim]
|
|
||||||
|
|
||||||
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
||||||
cache_kv_c = kv_c_and_k_pe_cache[0]
|
cache_kv_c = kv_c_and_k_pe_cache[0]
|
||||||
@@ -575,19 +598,19 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
kv_c_normed = torch.empty(toks,
|
kv_c_normed = torch.empty(toks,
|
||||||
num_heads,
|
num_heads,
|
||||||
latent_kv_dim,
|
latent_kv_dim,
|
||||||
dtype=query.dtype,
|
dtype=q_nope.dtype,
|
||||||
device=query.device)
|
device=q_nope.device)
|
||||||
k_pe = torch.empty(toks,
|
k_pe = torch.empty(toks,
|
||||||
num_heads,
|
num_heads,
|
||||||
rope_dim,
|
rope_dim,
|
||||||
dtype=query.dtype,
|
dtype=q_nope.dtype,
|
||||||
device=query.device)
|
device=q_nope.device)
|
||||||
|
|
||||||
torch_npu.atb.npu_paged_cache_load(
|
torch_npu.atb.npu_paged_cache_load(
|
||||||
cache_kv_c,
|
cache_kv_c,
|
||||||
cache_k_pe,
|
cache_k_pe,
|
||||||
prefill_metadata.block_table,
|
prefill_metadata.block_table,
|
||||||
seq_len2.to(query.device),
|
seq_len2.to(q_nope.device),
|
||||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||||
key=kv_c_normed,
|
key=kv_c_normed,
|
||||||
value=k_pe,
|
value=k_pe,
|
||||||
@@ -599,16 +622,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
k_nope, v = kv_nope\
|
k_nope, v = kv_nope\
|
||||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||||
mask = torch.triu(
|
|
||||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
|
||||||
1)
|
|
||||||
torch_npu.atb.npu_ring_mla(
|
torch_npu.atb.npu_ring_mla(
|
||||||
q_nope=q_nope,
|
q_nope=q_nope,
|
||||||
q_rope=q_pe,
|
q_rope=q_pe,
|
||||||
k_nope=k_nope,
|
k_nope=k_nope,
|
||||||
k_rope=k_pe,
|
k_rope=k_pe,
|
||||||
value=v,
|
value=v,
|
||||||
mask=mask,
|
mask=self.prefill_mask,
|
||||||
seqlen=seq_len,
|
seqlen=seq_len,
|
||||||
head_num=self.num_heads,
|
head_num=self.num_heads,
|
||||||
kv_head_num=self.num_heads,
|
kv_head_num=self.num_heads,
|
||||||
@@ -625,33 +645,74 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
def _forward_prefill(
|
def _forward_prefill(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
q_nope: torch.Tensor,
|
||||||
kv_c_normed: torch.Tensor,
|
q_pe: torch.Tensor,
|
||||||
|
k_nope: torch.Tensor,
|
||||||
k_pe: torch.Tensor,
|
k_pe: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
||||||
attn_metadata: AscendMLAMetadata,
|
attn_metadata: AscendMLAMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
assert len(kv_c_and_k_pe_cache) > 1
|
assert len(kv_c_and_k_pe_cache) > 1
|
||||||
|
num_tokens = q_nope.size(0)
|
||||||
num_tokens = query.size(0)
|
|
||||||
attn_output = torch.empty(num_tokens,
|
attn_output = torch.empty(num_tokens,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.v_head_dim,
|
self.v_head_dim,
|
||||||
dtype=query.dtype,
|
dtype=q_nope.dtype,
|
||||||
device=query.device)
|
device=q_nope.device)
|
||||||
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
query = torch.cat((q_nope, q_pe), dim=-1)
|
||||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
key = torch.cat((k_nope, k_pe), dim=-1)
|
||||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
torch_npu._npu_flash_attention(
|
||||||
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
query=query,
|
||||||
ascend_config = get_ascend_config()
|
key=key,
|
||||||
|
value=value,
|
||||||
if attn_metadata.attn_state in [
|
mask=attn_metadata.attn_mask,
|
||||||
AscendAttentionState.ChunkedPrefill,
|
seq_len=attn_metadata.prefill.context_lens,
|
||||||
AscendAttentionState.SpecDecoding,
|
scale_value=self.scale,
|
||||||
AscendAttentionState.PrefillCacheHit
|
num_heads=self.num_heads,
|
||||||
] and not ascend_config.chunked_prefill_for_mla:
|
num_kv_heads=self.num_heads,
|
||||||
|
out=attn_output)
|
||||||
|
elif self.chunked_prefill_for_mla:
|
||||||
|
attn_lse = torch.empty(self.num_heads,
|
||||||
|
num_tokens,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=q_nope.device)
|
||||||
|
if self.prefill_mask is None:
|
||||||
|
self.prefill_mask = torch.triu(
|
||||||
|
torch.ones(512,
|
||||||
|
512,
|
||||||
|
device=q_nope.device,
|
||||||
|
dtype=q_nope.dtype),
|
||||||
|
1) # 512: mask only support 512
|
||||||
|
if attn_metadata.num_prefills > 1:
|
||||||
|
self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
|
||||||
|
attn_metadata.num_prefills, 1, 1)
|
||||||
|
torch_npu.atb.npu_ring_mla(
|
||||||
|
q_nope=q_nope,
|
||||||
|
q_rope=q_pe,
|
||||||
|
k_nope=k_nope,
|
||||||
|
k_rope=k_pe,
|
||||||
|
value=value,
|
||||||
|
mask=self.prefill_mask,
|
||||||
|
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
||||||
|
dtype=torch.int32),
|
||||||
|
head_num=self.num_heads,
|
||||||
|
kv_head_num=self.num_heads,
|
||||||
|
pre_out=None,
|
||||||
|
prev_lse=None,
|
||||||
|
qk_scale=self.scale,
|
||||||
|
kernel_type="kernel_type_high_precision",
|
||||||
|
mask_type="mask_type_triu",
|
||||||
|
input_layout="type_bsnd",
|
||||||
|
calc_type="calc_type_first_ring",
|
||||||
|
output=attn_output,
|
||||||
|
softmax_lse=attn_lse)
|
||||||
|
attn_output, attn_lse = self._compute_prefill_context( \
|
||||||
|
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||||
|
else:
|
||||||
|
query = torch.cat((q_nope, q_pe), dim=-1)
|
||||||
attn_output_torch = torch.empty(num_tokens,
|
attn_output_torch = torch.empty(num_tokens,
|
||||||
self.num_heads * self.v_head_dim,
|
self.num_heads * self.v_head_dim,
|
||||||
dtype=query.dtype,
|
dtype=query.dtype,
|
||||||
@@ -673,240 +734,318 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
alibi_slopes=None,
|
alibi_slopes=None,
|
||||||
causal=True)
|
causal=True)
|
||||||
elif attn_metadata.attn_state in [
|
|
||||||
AscendAttentionState.ChunkedPrefill,
|
|
||||||
AscendAttentionState.SpecDecoding,
|
|
||||||
AscendAttentionState.PrefillCacheHit
|
|
||||||
]:
|
|
||||||
attn_lse = torch.empty(self.num_heads,
|
|
||||||
num_tokens,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=query.device)
|
|
||||||
q_pe = query[..., self.qk_nope_head_dim:]
|
|
||||||
q_nope = query[..., :self.qk_nope_head_dim]
|
|
||||||
mask = torch.triu(
|
|
||||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
|
||||||
1) # 512: mask only support 512
|
|
||||||
if attn_metadata.num_prefills > 1:
|
|
||||||
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
|
|
||||||
1)
|
|
||||||
torch_npu.atb.npu_ring_mla(
|
|
||||||
q_nope=q_nope,
|
|
||||||
q_rope=q_pe,
|
|
||||||
k_nope=k_nope,
|
|
||||||
k_rope=k_pe,
|
|
||||||
value=value,
|
|
||||||
mask=mask,
|
|
||||||
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
|
||||||
dtype=torch.int32),
|
|
||||||
head_num=self.num_heads,
|
|
||||||
kv_head_num=self.num_heads,
|
|
||||||
pre_out=None,
|
|
||||||
prev_lse=None,
|
|
||||||
qk_scale=self.scale,
|
|
||||||
kernel_type="kernel_type_high_precision",
|
|
||||||
mask_type="mask_type_triu",
|
|
||||||
input_layout="type_bsnd",
|
|
||||||
calc_type="calc_type_first_ring",
|
|
||||||
output=attn_output,
|
|
||||||
softmax_lse=attn_lse)
|
|
||||||
attn_output, attn_lse = self._compute_prefill_context( \
|
|
||||||
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
|
||||||
|
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
|
||||||
key = torch.cat((k_nope, k_pe), dim=-1)
|
|
||||||
torch_npu._npu_flash_attention(
|
|
||||||
query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
mask=attn_metadata.attn_mask,
|
|
||||||
seq_len=attn_metadata.prefill.context_lens,
|
|
||||||
scale_value=self.scale,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_kv_heads=self.num_heads,
|
|
||||||
out=attn_output)
|
|
||||||
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
|
|
||||||
)
|
|
||||||
attn_output = attn_output.reshape(
|
attn_output = attn_output.reshape(
|
||||||
[num_tokens, self.num_heads * self.v_head_dim])
|
[num_tokens, self.num_heads * self.v_head_dim])
|
||||||
if attn_metadata.attn_state in [
|
if attn_metadata.attn_state in [
|
||||||
AscendAttentionState.ChunkedPrefill,
|
AscendAttentionState.ChunkedPrefill,
|
||||||
AscendAttentionState.SpecDecoding,
|
AscendAttentionState.SpecDecoding,
|
||||||
AscendAttentionState.PrefillCacheHit
|
AscendAttentionState.PrefillCacheHit
|
||||||
] and not ascend_config.chunked_prefill_for_mla:
|
] and not self.chunked_prefill_for_mla:
|
||||||
attn_output = attn_output_torch
|
attn_output = attn_output_torch
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
def exec_kv_decode(
|
||||||
|
self,
|
||||||
|
kv_no_split: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
kv_cache: Tuple,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
B = kv_no_split.shape[0]
|
||||||
|
N = self.num_kv_heads
|
||||||
|
S = 1
|
||||||
|
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||||
|
kv_no_split = kv_no_split.view(
|
||||||
|
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||||
|
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
||||||
|
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||||
|
kv_no_split,
|
||||||
|
self.kv_a_layernorm.weight,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
slots.to(torch.int64),
|
||||||
|
kv_cache[1],
|
||||||
|
kv_cache[0],
|
||||||
|
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||||
|
cache_mode=cache_mode,
|
||||||
|
)
|
||||||
|
return k_pe, k_nope
|
||||||
|
|
||||||
|
def exec_kv_prefill(
|
||||||
|
self,
|
||||||
|
kv_no_split: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
kv_cache: Tuple,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
B = kv_no_split.shape[0]
|
||||||
|
N = self.num_kv_heads
|
||||||
|
S = 1
|
||||||
|
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
||||||
|
kv_no_split = kv_no_split.view(
|
||||||
|
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
||||||
|
cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
|
||||||
|
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||||
|
kv_no_split,
|
||||||
|
self.kv_a_layernorm.weight,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
slots.to(torch.int64),
|
||||||
|
kv_cache[1],
|
||||||
|
kv_cache[0],
|
||||||
|
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||||
|
cache_mode=cache_mode,
|
||||||
|
is_output_kv=True,
|
||||||
|
)
|
||||||
|
return k_pe, k_nope
|
||||||
|
|
||||||
|
def rope_single(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
B, N, D = x.shape
|
||||||
|
S = 1
|
||||||
|
x = x.view(B, N, S, D)
|
||||||
|
x = torch_npu.npu_interleave_rope(x, cos, sin)
|
||||||
|
return x.view(B, N, D)
|
||||||
|
|
||||||
def _forward_decode(
|
def _forward_decode(
|
||||||
self,
|
self,
|
||||||
q_nope: torch.Tensor,
|
q_nope: torch.Tensor,
|
||||||
q_pe: torch.Tensor,
|
q_pe: torch.Tensor,
|
||||||
k_nope: torch.Tensor,
|
k_nope: torch.Tensor,
|
||||||
k_pe: torch.Tensor,
|
k_pe: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
block_size: int,
|
||||||
attn_metadata: AscendMLAMetadata,
|
attn_metadata: AscendMLAMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
decode_meta = attn_metadata.decode
|
decode_meta = attn_metadata.decode
|
||||||
assert decode_meta is not None
|
assert decode_meta is not None
|
||||||
num_tokens = q_nope.size(0)
|
num_tokens = q_nope.size(0)
|
||||||
# The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
|
# shape of knope/k_pe for npu graph mode should be:
|
||||||
# be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
|
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
||||||
# public available
|
actual_seq_lengths = None
|
||||||
assert len(kv_c_and_k_pe_cache) > 1
|
if self.enable_kv_nz:
|
||||||
if envs_ascend.VLLM_ASCEND_MLA_PA:
|
k_nope = k_nope.view(-1, self.num_kv_heads,
|
||||||
attn_output = torch_npu.atb.npu_multi_head_latent_attention(
|
self.kv_lora_rank // 16, block_size, 16)
|
||||||
q_nope, q_pe, kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1],
|
k_pe = k_pe.view(-1, self.num_kv_heads,
|
||||||
attn_metadata.decode.block_table,
|
self.qk_rope_head_dim // 16, block_size, 16)
|
||||||
attn_metadata.decode.seq_lens, self.num_heads, self.scale,
|
input_layout = "BSND"
|
||||||
self.num_kv_heads)
|
|
||||||
else:
|
else:
|
||||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
||||||
attn_output = torch.empty(
|
self.kv_lora_rank)
|
||||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
||||||
dtype=q.dtype,
|
self.qk_rope_head_dim)
|
||||||
device=q.device)
|
input_layout = "BNSD"
|
||||||
k_cache = torch.cat(
|
|
||||||
[kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
|
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||||
torch_npu._npu_paged_attention_mla(
|
assert num_tokens % self.spec_token_num == 0
|
||||||
query=q,
|
input_layout = "TND"
|
||||||
key_cache=k_cache,
|
# [bs * q_seq_len, num_heads_per_rank, dim]
|
||||||
num_kv_heads=self.num_kv_heads,
|
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
|
||||||
num_heads=self.num_heads,
|
q_pe = q_pe.view(num_tokens, self.num_heads, -1)
|
||||||
scale_value=self.scale,
|
sparse_mode = 3
|
||||||
block_table=attn_metadata.decode.block_table, # type:ignore
|
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||||
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||||
mla_vheadsize=self.kv_lora_rank,
|
else:
|
||||||
out=attn_output)
|
if self.enable_kv_nz:
|
||||||
|
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
|
||||||
|
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
|
||||||
|
else:
|
||||||
|
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||||
|
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||||
|
sparse_mode = 0
|
||||||
|
spec_attn_mask = None
|
||||||
|
|
||||||
|
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||||
|
q_nope,
|
||||||
|
k_nope,
|
||||||
|
k_nope,
|
||||||
|
query_rope=q_pe,
|
||||||
|
key_rope=k_pe,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_key_value_heads=self.num_kv_heads,
|
||||||
|
input_layout=input_layout,
|
||||||
|
atten_mask=spec_attn_mask,
|
||||||
|
sparse_mode=sparse_mode,
|
||||||
|
scale=self.scale,
|
||||||
|
antiquant_mode=0,
|
||||||
|
antiquant_scale=None,
|
||||||
|
block_table=decode_meta.block_table,
|
||||||
|
block_size=block_size,
|
||||||
|
actual_seq_lengths_kv=decode_meta.seq_lens_list,
|
||||||
|
actual_seq_lengths=actual_seq_lengths)
|
||||||
|
|
||||||
current_ms_metadata = get_multistream_comm_context()
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
if current_ms_metadata is None:
|
if current_ms_metadata is None:
|
||||||
return self._v_up_proj_and_o_proj(attn_output)
|
return self._v_up_proj(attn_output)
|
||||||
else:
|
else:
|
||||||
current_ms_metadata.before_comm_event.record()
|
current_ms_metadata.before_comm_event.record()
|
||||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||||
current_ms_metadata.before_comm_event.wait()
|
current_ms_metadata.before_comm_event.wait()
|
||||||
return self._v_up_proj_and_o_proj(attn_output)
|
return self._v_up_proj(attn_output)
|
||||||
|
|
||||||
|
def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
|
||||||
|
need_gather_q_kv):
|
||||||
|
# MLA Preprocess:
|
||||||
|
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
|
||||||
|
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
|
||||||
|
# 3. If need_gather_q_kv, perform all_gather.
|
||||||
|
# 4. Preprocess decode tokens, write kv cache and get:
|
||||||
|
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
|
||||||
|
# 5. Preprocess prefill tokens, write kv cache and get:
|
||||||
|
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
|
||||||
|
has_decode = attn_metadata.num_decodes > 0
|
||||||
|
has_prefill = attn_metadata.num_prefills > 0
|
||||||
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
|
if self.q_a_proj is not None:
|
||||||
|
npu_prefetch(self.q_a_proj.weight,
|
||||||
|
hidden_states,
|
||||||
|
enabled=self.enable_prefetch)
|
||||||
|
ckq = self.q_a_proj(hidden_states)[0]
|
||||||
|
q_c = self.q_a_layernorm(ckq)
|
||||||
|
else:
|
||||||
|
q_c = hidden_states
|
||||||
|
|
||||||
|
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||||
|
# Process for shared_expert_dp
|
||||||
|
if need_gather_q_kv:
|
||||||
|
q_c = get_tp_group().all_gather(q_c, 0)
|
||||||
|
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
|
||||||
|
decode_preprocess_res = None
|
||||||
|
prefill_preprocess_res = None
|
||||||
|
# Preprocess for decode tokens
|
||||||
|
if has_decode:
|
||||||
|
decode_q_c = q_c[:num_decode_tokens]
|
||||||
|
cos = attn_metadata.decode.cos
|
||||||
|
sin = attn_metadata.decode.sin
|
||||||
|
decode_ql_nope, decode_q_pe = \
|
||||||
|
self._q_proj_and_k_up_proj(decode_q_c)
|
||||||
|
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||||
|
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens]
|
||||||
|
decode_kv_no_split = kv_no_split[:num_decode_tokens]
|
||||||
|
decode_k_pe, decode_k_nope = self.exec_kv_decode(
|
||||||
|
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
|
||||||
|
decode_preprocess_res = DecodeMLAPreprocessResult(
|
||||||
|
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
||||||
|
# Preprocess for prefill tokens
|
||||||
|
if has_prefill:
|
||||||
|
prefill_kv_no_split = kv_no_split[
|
||||||
|
num_decode_tokens:num_actual_tokens]
|
||||||
|
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
|
||||||
|
prefill_q = self.q_proj(prefill_q_c)[0]\
|
||||||
|
.view(-1, self.num_heads, self.qk_head_dim)
|
||||||
|
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||||
|
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
||||||
|
cos = attn_metadata.prefill.cos
|
||||||
|
sin = attn_metadata.prefill.sin
|
||||||
|
prefill_slots = attn_metadata.slot_mapping[
|
||||||
|
num_decode_tokens:num_actual_tokens]
|
||||||
|
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||||
|
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
|
||||||
|
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
|
||||||
|
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
|
||||||
|
self.num_kv_heads, -1)
|
||||||
|
prefill_k_nope, prefill_value = self.kv_b_proj(
|
||||||
|
prefill_k_c_normed)[0].view(
|
||||||
|
-1, self.num_heads,
|
||||||
|
self.qk_nope_head_dim + self.v_head_dim).split(
|
||||||
|
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
prefill_k_pe = prefill_k_pe.expand(
|
||||||
|
(*prefill_k_nope.shape[:-1], -1))
|
||||||
|
prefill_preprocess_res = PrefillMLAPreprocessResult(
|
||||||
|
prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
|
||||||
|
prefill_value)
|
||||||
|
return decode_preprocess_res, prefill_preprocess_res
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: AttentionLayer,
|
hidden_states: torch.Tensor, # query in unified attn
|
||||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
|
||||||
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
|
|
||||||
k_pe: torch.Tensor, # value in unified attn
|
|
||||||
kv_cache: Tuple[torch.Tensor],
|
kv_cache: Tuple[torch.Tensor],
|
||||||
attn_metadata: M,
|
attn_metadata: M,
|
||||||
|
need_gather_q_kv: bool = False,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
ckq: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# Profiling run.
|
# Profiling run.
|
||||||
return output
|
return output
|
||||||
num_actual_toks = attn_metadata.num_actual_tokens
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
if k_pe is None:
|
|
||||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
|
||||||
hidden_states_or_kv_c_normed)[0].split(
|
|
||||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
|
||||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
|
||||||
else:
|
|
||||||
kv_c_normed = hidden_states_or_kv_c_normed
|
|
||||||
assert attn_metadata.num_decodes is not None and \
|
assert attn_metadata.num_decodes is not None and \
|
||||||
attn_metadata.num_prefills is not None and \
|
attn_metadata.num_prefills is not None and \
|
||||||
attn_metadata.num_decode_tokens is not None
|
attn_metadata.num_decode_tokens is not None
|
||||||
has_decode = attn_metadata.num_decodes > 0
|
|
||||||
has_prefill = attn_metadata.num_prefills > 0
|
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
# Inputs and outputs may be padded for CUDA graphs
|
# Inputs and outputs may be padded for CUDA graphs
|
||||||
output_padded = output
|
output_padded = output
|
||||||
output = output[:num_actual_toks, ...]
|
output = output[:num_actual_tokens, ...]
|
||||||
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
|
o_proj_input_shape = (num_actual_tokens,
|
||||||
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
|
|
||||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
|
||||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
|
||||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
|
||||||
k_pe = k_pe[:num_actual_toks, ...]
|
|
||||||
k_pe = k_pe.unsqueeze(1)
|
|
||||||
decode_k_pe = k_pe[:num_decode_tokens]
|
|
||||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
|
||||||
if has_decode:
|
|
||||||
decode_k_nope = None
|
|
||||||
assert attn_metadata.decode is not None
|
|
||||||
decode_ql_nope, decode_q_pe = \
|
|
||||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
|
||||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
|
||||||
attn_metadata.decode.input_positions,
|
|
||||||
decode_q_pe.contiguous(),
|
|
||||||
decode_k_pe,
|
|
||||||
max_seq_len=attn_metadata.decode.max_seq_lens)
|
|
||||||
if has_prefill:
|
|
||||||
assert attn_metadata.prefill is not None
|
|
||||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
|
||||||
.view(-1, self.num_heads, self.qk_head_dim)
|
|
||||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
|
||||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
|
||||||
attn_metadata.prefill.input_positions,
|
|
||||||
prefill_q_pe.contiguous(),
|
|
||||||
prefill_k_pe,
|
|
||||||
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
|
||||||
|
|
||||||
assert len(
|
|
||||||
kv_cache
|
|
||||||
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
|
|
||||||
kv_c_normed = kv_c_normed.view(
|
|
||||||
[num_actual_toks, self.num_kv_heads, -1])
|
|
||||||
torch_npu._npu_reshape_and_cache(
|
|
||||||
key=kv_c_normed,
|
|
||||||
value=k_pe,
|
|
||||||
key_cache=kv_cache[0],
|
|
||||||
value_cache=kv_cache[1],
|
|
||||||
slot_indices=attn_metadata.slot_mapping)
|
|
||||||
o_proj_input_shape = (num_actual_toks,
|
|
||||||
self.num_heads * self.v_head_dim)
|
self.num_heads * self.v_head_dim)
|
||||||
o_proj_input = torch.empty(o_proj_input_shape,
|
o_proj_input = torch.empty(o_proj_input_shape,
|
||||||
dtype=hidden_states_or_q_c.dtype,
|
dtype=hidden_states.dtype,
|
||||||
device=hidden_states_or_q_c.device)
|
device=hidden_states.device)
|
||||||
if has_prefill:
|
|
||||||
# FIX: aicore move should be also placed on the comm stream in dbo,
|
|
||||||
# otherwise it may affect the accuracy
|
|
||||||
# TODO: use an elegant way to overlap
|
|
||||||
output_prefill = self._forward_prefill(prefill_q,
|
|
||||||
prefill_k_c_normed,
|
|
||||||
prefill_k_pe, kv_cache,
|
|
||||||
attn_metadata)
|
|
||||||
current_ms_metadata = get_multistream_comm_context()
|
|
||||||
if current_ms_metadata is not None:
|
|
||||||
current_ms_metadata.before_comm_event.record()
|
|
||||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
|
||||||
current_ms_metadata.before_comm_event.wait()
|
|
||||||
o_proj_input[num_decode_tokens:] = output_prefill
|
|
||||||
else:
|
|
||||||
o_proj_input[num_decode_tokens:] = output_prefill
|
|
||||||
|
|
||||||
if has_decode:
|
# MLA Preprocess
|
||||||
output_decode = self._forward_decode(decode_ql_nope, decode_q_pe,
|
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
|
||||||
decode_k_nope, decode_k_pe,
|
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
|
||||||
kv_cache, attn_metadata)
|
|
||||||
|
if decode_preprocess_res is not None:
|
||||||
|
# MLA Preprocess for decoding
|
||||||
|
output_decode = self._forward_decode(decode_preprocess_res.ql_nope,
|
||||||
|
decode_preprocess_res.q_pe,
|
||||||
|
decode_preprocess_res.k_nope,
|
||||||
|
decode_preprocess_res.k_pe,
|
||||||
|
kv_cache[0].shape[1],
|
||||||
|
attn_metadata)
|
||||||
current_ms_metadata = get_multistream_comm_context()
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
if current_ms_metadata is not None:
|
if current_ms_metadata is not None:
|
||||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||||
o_proj_input[:num_decode_tokens] = output_decode
|
o_proj_input[:num_decode_tokens] = output_decode
|
||||||
|
current_ms_metadata.after_comm_event.record()
|
||||||
else:
|
else:
|
||||||
o_proj_input[:num_decode_tokens] = output_decode
|
o_proj_input[:num_decode_tokens] = output_decode
|
||||||
|
|
||||||
|
if prefill_preprocess_res is not None:
|
||||||
|
# FIX: aicore move should be also placed on the comm stream in dbo,
|
||||||
|
# otherwise it may affect the accuracy
|
||||||
|
# TODO: use an elegant way to overlap
|
||||||
|
output_prefill = self._forward_prefill(
|
||||||
|
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
|
||||||
|
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
|
||||||
|
prefill_preprocess_res.value, kv_cache, attn_metadata)
|
||||||
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
|
if current_ms_metadata is not None:
|
||||||
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||||
|
o_proj_input[num_decode_tokens:] = output_prefill
|
||||||
|
current_ms_metadata.after_comm_event.record()
|
||||||
|
else:
|
||||||
|
o_proj_input[num_decode_tokens:] = output_prefill
|
||||||
|
# O proj
|
||||||
current_ms_metadata = get_multistream_comm_context()
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
|
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
|
||||||
if current_ms_metadata is None:
|
if current_ms_metadata is None:
|
||||||
|
npu_prefetch(self.o_proj.weight,
|
||||||
|
o_proj_input,
|
||||||
|
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||||
|
enabled=self.enable_prefetch)
|
||||||
|
|
||||||
output[...] = self.o_proj(
|
output[...] = self.o_proj(
|
||||||
o_proj_input,
|
o_proj_input,
|
||||||
is_prefill=True,
|
is_prefill=prefill_preprocess_res is not None,
|
||||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||||
else:
|
else:
|
||||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||||
|
npu_prefetch(self.o_proj.weight,
|
||||||
|
o_proj_input,
|
||||||
|
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||||
|
enabled=self.enable_prefetch)
|
||||||
output[...] = self.o_proj(
|
output[...] = self.o_proj(
|
||||||
o_proj_input,
|
o_proj_input,
|
||||||
is_prefill=True,
|
is_prefill=prefill_preprocess_res is not None,
|
||||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||||
current_ms_metadata.after_comm_event.record()
|
current_ms_metadata.after_comm_event.record()
|
||||||
del o_proj_input
|
del o_proj_input
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
|||||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
get_tp_group, split_tensor_along_last_dim,
|
get_tp_group, split_tensor_along_last_dim,
|
||||||
tensor_model_parallel_all_gather,
|
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
tensor_model_parallel_reduce_scatter)
|
tensor_model_parallel_reduce_scatter)
|
||||||
from vllm.distributed.parallel_state import get_dp_group, get_ep_group
|
from vllm.distributed.parallel_state import get_dp_group, get_ep_group
|
||||||
@@ -73,7 +72,7 @@ from vllm_ascend.ascend_config import get_ascend_config
|
|||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||||
from vllm_ascend.utils import dispose_tensor, npu_prefetch
|
from vllm_ascend.utils import dispose_tensor
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepseekV2SiluAndMul(SiluAndMul):
|
class CustomDeepseekV2SiluAndMul(SiluAndMul):
|
||||||
@@ -471,9 +470,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
|
||||||
self.enable_multistream_mla = \
|
|
||||||
ascend_config.torchair_graph_config.enable_multistream_mla
|
|
||||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||||
|
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
@@ -515,8 +511,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
if (config.n_routed_experts is not None
|
if (config.n_routed_experts is not None
|
||||||
and self.debug_layer_idx >= config.first_k_dense_replace
|
and self.debug_layer_idx >= config.first_k_dense_replace
|
||||||
and self.debug_layer_idx % config.moe_layer_freq == 0
|
and self.debug_layer_idx % config.moe_layer_freq == 0
|
||||||
and (ascend_config.torchair_graph_config.enable_multistream_moe
|
and self.enable_shared_expert_dp):
|
||||||
or self.enable_shared_expert_dp)):
|
|
||||||
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
|
self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
|
||||||
self.num_heads * self.v_head_dim,
|
self.num_heads * self.v_head_dim,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
@@ -568,6 +563,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
qk_head_dim=self.qk_head_dim,
|
qk_head_dim=self.qk_head_dim,
|
||||||
v_head_dim=self.v_head_dim,
|
v_head_dim=self.v_head_dim,
|
||||||
rotary_emb=self.rotary_emb,
|
rotary_emb=self.rotary_emb,
|
||||||
|
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
||||||
|
q_a_layernorm=self.q_a_layernorm
|
||||||
|
if self.q_lora_rank is not None else None,
|
||||||
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||||
kv_a_layernorm=self.kv_a_layernorm,
|
kv_a_layernorm=self.kv_a_layernorm,
|
||||||
@@ -582,55 +580,29 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
kv_cache: Optional[torch.Tensor] = None,
|
kv_cache: Optional[torch.Tensor] = None,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
enable_multistream_mla = (self.enable_multistream_mla
|
if kv_cache is None:
|
||||||
and attn_metadata is not None
|
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
||||||
and not forward_context.with_prefill
|
num_tokens = hidden_states.shape[0]
|
||||||
and attn_metadata.num_decodes > 0)
|
need_gather_q_kv = False
|
||||||
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
|
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
||||||
if self.q_lora_rank is not None:
|
# Simulate all gather to calculate output shape
|
||||||
npu_prefetch(self.q_a_proj.weight,
|
num_tokens = num_tokens * self.tp_size
|
||||||
hidden_states,
|
need_gather_q_kv = True
|
||||||
enabled=enable_multistream_mla)
|
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
|
||||||
ckq = self.q_a_proj(hidden_states)[0]
|
|
||||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
|
||||||
forward_kwargs['ckq'] = ckq
|
|
||||||
else:
|
|
||||||
hidden_states_or_q_c = hidden_states
|
|
||||||
if self.torchair_graph_enabled:
|
|
||||||
output_shape = hidden_states.shape
|
output_shape = hidden_states.shape
|
||||||
output = torch.empty(output_shape,
|
|
||||||
dtype=hidden_states_or_q_c.dtype,
|
|
||||||
device=hidden_states_or_q_c.device)
|
|
||||||
forward_kwargs['output'] = output
|
|
||||||
output = self.mla_attn.impl.forward(self.mla_attn,
|
|
||||||
hidden_states_or_q_c,
|
|
||||||
hidden_states, None, kv_cache,
|
|
||||||
attn_metadata,
|
|
||||||
**forward_kwargs)
|
|
||||||
output = output.view(-1, output_shape[-1])
|
|
||||||
return output
|
|
||||||
else:
|
else:
|
||||||
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
|
rows = num_tokens // self.tp_size
|
||||||
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
if num_tokens % self.tp_size:
|
||||||
hidden_states_or_q_c = get_tp_group().all_gather(
|
rows += 1
|
||||||
hidden_states_or_q_c, 0)
|
output_shape = (rows, hidden_states.shape[1])
|
||||||
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
|
output = torch.empty(output_shape,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
kv_c, k_pe = kv_no_split.split(
|
device=hidden_states.device)
|
||||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
output = self.mla_attn.impl.forward(hidden_states, kv_cache,
|
||||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
forward_context.attn_metadata,
|
||||||
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
|
need_gather_q_kv, output)
|
||||||
output_shape = hidden_states.shape
|
output = output.view(-1, output_shape[-1])
|
||||||
else:
|
return output
|
||||||
num_tokens = hidden_states_or_q_c.shape[0]
|
|
||||||
rows = num_tokens // self.tp_size
|
|
||||||
if num_tokens % self.tp_size:
|
|
||||||
rows += 1
|
|
||||||
output_shape = (rows, hidden_states.shape[1])
|
|
||||||
return self.mla_attn(hidden_states_or_q_c,
|
|
||||||
kv_c_normed,
|
|
||||||
k_pe,
|
|
||||||
output_shape=output_shape)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||||
@@ -688,8 +660,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.mlp",
|
prefix=f"{prefix}.mlp",
|
||||||
)
|
)
|
||||||
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
|
|
||||||
and model_config.use_mla and self.tp_size > 1
|
|
||||||
else:
|
else:
|
||||||
self.mlp = CustomDeepseekV2MLP(
|
self.mlp = CustomDeepseekV2MLP(
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
@@ -698,7 +668,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.mlp",
|
prefix=f"{prefix}.mlp",
|
||||||
)
|
)
|
||||||
self.mla_moe_communication = False
|
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
@@ -718,10 +687,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
if attn_metadata is not None and attn_metadata.num_decodes > 0:
|
|
||||||
mla_moe_communication = self.mla_moe_communication and replace_allreduce
|
|
||||||
else:
|
|
||||||
mla_moe_communication = False
|
|
||||||
if residual is None:
|
if residual is None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
@@ -733,9 +698,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
# to save npu memory because they're no longer used.
|
# to save npu memory because they're no longer used.
|
||||||
dispose_tensor(previous_hidden_states)
|
dispose_tensor(previous_hidden_states)
|
||||||
dispose_tensor(previous_residual)
|
dispose_tensor(previous_residual)
|
||||||
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
|
|
||||||
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
|
||||||
dim=0)
|
|
||||||
|
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@@ -744,13 +706,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if mla_moe_communication and residual.shape[0] != hidden_states.shape[
|
|
||||||
0]:
|
|
||||||
chunk_hidden_states = torch.tensor_split(residual,
|
|
||||||
self.tp_size,
|
|
||||||
dim=0)
|
|
||||||
residual = chunk_hidden_states[self.tp_rank]
|
|
||||||
|
|
||||||
if hidden_states.dtype == torch.float16:
|
if hidden_states.dtype == torch.float16:
|
||||||
# Fix FP16 overflow
|
# Fix FP16 overflow
|
||||||
# We scale both hidden_states and residual before
|
# We scale both hidden_states and residual before
|
||||||
@@ -778,9 +733,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
hidden_states, residual)
|
hidden_states, residual)
|
||||||
|
|
||||||
if isinstance(self.mlp, CustomDeepseekV2MoE):
|
if isinstance(self.mlp, CustomDeepseekV2MoE):
|
||||||
hidden_states = self.mlp(hidden_states,
|
hidden_states = self.mlp(hidden_states, attn_metadata)
|
||||||
attn_metadata,
|
|
||||||
replace_allreduce=mla_moe_communication)
|
|
||||||
else:
|
else:
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
@@ -793,10 +746,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
# The scaling of DeepseekV2MOE output would be done in the forward
|
# The scaling of DeepseekV2MOE output would be done in the forward
|
||||||
# of DeepseekV2MOE
|
# of DeepseekV2MOE
|
||||||
hidden_states *= 1. / self.routed_scaling_factor
|
hidden_states *= 1. / self.routed_scaling_factor
|
||||||
if mla_moe_communication and self.layer_idx == self.layers - 1:
|
|
||||||
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
|
||||||
dim=0)
|
|
||||||
residual = tensor_model_parallel_all_gather(residual, dim=0)
|
|
||||||
|
|
||||||
# for last layer of main model and mtp layer.
|
# for last layer of main model and mtp layer.
|
||||||
if self.enable_shared_expert_dp and self.layer_idx >= (
|
if self.enable_shared_expert_dp and self.layer_idx >= (
|
||||||
|
|||||||
Reference in New Issue
Block a user