[Refactor]6/N Extract common code of class AscendMLAImpl (#5314)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason: Eliminate the duplicated code in the IMPL classes of two files (mla_v1.py and mla_cp.py).

vLLM version: 0.13.0rc3
vLLM main: ad32e3e19c

- vLLM version: release/v0.13.0
- vLLM main: 5fbfa8d9ef

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -254,7 +254,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
@patch('vllm_ascend.attention.mla_cp.get_dcp_group')
|
||||
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
|
||||
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
|
||||
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
|
||||
def test_mla_preprocess_dcp(self, magic_npu_fetch,
|
||||
mock_maybe_all_gather_and_maybe_unpad,
|
||||
mock_get_dcp_group):
|
||||
@@ -339,7 +339,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.mla_cp.get_pcp_group')
|
||||
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
|
||||
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
|
||||
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
|
||||
def test_mla_preprocess_pcp(self, magic_npu_fetch,
|
||||
mock_maybe_all_gather_and_maybe_unpad,
|
||||
mock_get_pcp_group,
|
||||
@@ -543,8 +543,8 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.impl._v_up_proj.return_value = torch.randn(
|
||||
B, self.impl.v_head_dim)
|
||||
|
||||
result = self.impl._forward_decode_pcp_dcp(q_nope, q_pe, k_nope, k_pe,
|
||||
BS, attn_metadata)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
|
||||
attn_metadata)
|
||||
|
||||
self.assertEqual(result.shape[0], B)
|
||||
self.assertEqual(result.shape[1], self.impl.v_head_dim)
|
||||
@@ -578,14 +578,14 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
def mock_reorg_kvcache(allgatered_kv_c_normed: torch.Tensor,
|
||||
allgatered_k_pe: torch.Tensor,
|
||||
padded_local_chunk_seq_lens_lst: list[int],
|
||||
local_context_lens_allranks: list[list[int]],
|
||||
sum_seq_len: int, max_seq_len: int,
|
||||
chunk_size: int, chunk_idx: int, toks: int):
|
||||
return torch.randn(sum_seq_len, allgatered_kv_c_normed.shape[1],
|
||||
chunked_context: CPChunkedContextMetadata,
|
||||
chunk_idx: int, toks: int):
|
||||
return torch.randn(
|
||||
chunked_context.cu_seq_lens_lst[chunk_idx][-1],
|
||||
allgatered_kv_c_normed.shape[1],
|
||||
allgatered_kv_c_normed.shape[2]), torch.randn(
|
||||
sum_seq_len, allgatered_k_pe.shape[1],
|
||||
allgatered_k_pe.shape[2])
|
||||
chunked_context.cu_seq_lens_lst[chunk_idx][-1],
|
||||
allgatered_k_pe.shape[1], allgatered_k_pe.shape[2])
|
||||
|
||||
# mock proj
|
||||
self.impl.kv_b_proj.side_effect = mock_kv_b_proj
|
||||
@@ -679,10 +679,6 @@ class TestAscendMLAImpl(TestBase):
|
||||
iters * (1 if dcp_size * pcp_size > 1 else 0))
|
||||
self.assertEqual(mock_load.call_count, iters)
|
||||
self.assertEqual(mock_ring.call_count, iters)
|
||||
self.assertEqual(mock_dcp.all_gather.call_count,
|
||||
(1 if dcp_size > 1 else 0))
|
||||
self.assertEqual(mock_pcp.all_gather.call_count,
|
||||
iters * (1 if pcp_size > 1 else 0))
|
||||
mock_reorg.reset_mock()
|
||||
mock_load.reset_mock()
|
||||
mock_ring.reset_mock()
|
||||
@@ -691,7 +687,18 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(out.shape, prefix_out.shape)
|
||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||
|
||||
def test_reorg_kvcache_with_dcp_pcp(self):
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
def test_reorg_kvcache_with_dcp_pcp(self, mock_dcp, mock_get_dcp_group,
|
||||
mock_pcp, mock_get_pcp_group):
|
||||
|
||||
def mock_all_gather(ws):
|
||||
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
|
||||
|
||||
BLOCK_SIZE = 128 # fixed
|
||||
max_model_len = 4096
|
||||
max_num_seqs = 25
|
||||
@@ -706,6 +713,12 @@ class TestAscendMLAImpl(TestBase):
|
||||
pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
|
||||
if pcp_size * dcp_size == 1:
|
||||
continue
|
||||
self.impl.dcp_size = dcp_size
|
||||
self.impl.pcp_size = pcp_size
|
||||
mock_dcp.all_gather = MagicMock(
|
||||
side_effect=mock_all_gather(dcp_size))
|
||||
mock_pcp.all_gather = MagicMock(
|
||||
side_effect=mock_all_gather(pcp_size))
|
||||
chunked_prefill_workspace_size = min(
|
||||
max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE),
|
||||
128 * 1024)
|
||||
@@ -723,27 +736,21 @@ class TestAscendMLAImpl(TestBase):
|
||||
|
||||
for i in range(len(chunked_context.seq_tot)):
|
||||
allgatered_kv_c_normed = torch.randn(
|
||||
chunked_context.seq_tot[i] * pcp_size * dcp_size,
|
||||
self.impl.num_heads, self.impl.v_head_dim)
|
||||
allgatered_k_pe = torch.randn(
|
||||
chunked_context.seq_tot[i] * pcp_size * dcp_size,
|
||||
self.impl.num_heads, self.impl.qk_rope_head_dim)
|
||||
chunked_context.seq_tot[i], self.impl.num_heads,
|
||||
self.impl.kv_lora_rank)
|
||||
allgatered_k_pe = torch.randn(chunked_context.seq_tot[i],
|
||||
self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim)
|
||||
result_kv, result_k_pe = self.impl._reorg_kvcache(
|
||||
allgatered_kv_c_normed,
|
||||
allgatered_k_pe,
|
||||
padded_local_chunk_seq_lens_lst=chunked_context.
|
||||
padded_local_chunk_seq_lens[i],
|
||||
local_context_lens_allranks=chunked_context.
|
||||
local_context_lens_allranks,
|
||||
sum_seq_len=chunked_context.cu_seq_lens_lst[i][-1],
|
||||
max_seq_len=chunked_context.max_seq_lens[i],
|
||||
chunk_size=chunked_context.chunk_size,
|
||||
chunked_context,
|
||||
chunk_idx=i,
|
||||
toks=chunked_context.seq_tot[i],
|
||||
)
|
||||
self.assertEqual(result_kv.shape,
|
||||
(chunked_context.cu_seq_lens_lst[i][-1],
|
||||
self.impl.num_heads, self.impl.v_head_dim))
|
||||
self.impl.num_heads, self.impl.kv_lora_rank))
|
||||
self.assertEqual(
|
||||
result_k_pe.shape,
|
||||
(chunked_context.cu_seq_lens_lst[i][-1],
|
||||
@@ -754,6 +761,11 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(result_k_pe.shape[0],
|
||||
chunked_context.cu_seq_lens_lst[i][-1])
|
||||
|
||||
self.assertEqual(mock_dcp.all_gather.call_count,
|
||||
(1 if dcp_size > 1 else 0))
|
||||
self.assertEqual(mock_pcp.all_gather.call_count,
|
||||
(1 if pcp_size > 1 else 0))
|
||||
|
||||
def test_out_lse_reshape(self):
|
||||
test_cases = [10, 1, 128, 512]
|
||||
for test_case in test_cases:
|
||||
@@ -1052,9 +1064,8 @@ class TestAscendMLAImpl(TestBase):
|
||||
attn_metadata.prefill.pcp_metadata.pcp_prefill_mask = torch.triu(
|
||||
torch.ones(10, 10, dtype=torch.float16), 1)
|
||||
|
||||
output = self.impl._forward_prefill_cp(q_nope, q_pe, k_nope,
|
||||
k_pe, value,
|
||||
kv_c_and_k_pe_cache,
|
||||
output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
|
||||
value, kv_c_and_k_pe_cache,
|
||||
attn_metadata)
|
||||
self.assertEqual(
|
||||
output.shape,
|
||||
|
||||
@@ -23,16 +23,12 @@ from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
||||
PrefillMLAPreprocessResult)
|
||||
#isort: on
|
||||
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.attention.common_cp import AscendPCPMetadata, CPChunkedContextMetadata
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata)
|
||||
from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
|
||||
CPChunkedContextMetadata)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
get_mtp_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.shared_weight_layer import (
|
||||
is_hidden_layer, reach_layer_for_shared_weight_series)
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import weak_ref_tensors
|
||||
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
|
||||
@@ -197,8 +193,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
+ self.
|
||||
num_prefills]
|
||||
|
||||
def set_decode_block_table(
|
||||
self, common_attn_metadata: AscendCommonAttentionMetadata):
|
||||
def set_decode_block_table(self):
|
||||
self.block_table = self.block_table[:self.num_decodes_flatten, ...]
|
||||
|
||||
def build_prefill_metadata(
|
||||
@@ -280,6 +275,12 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
self.dcp_group = get_dcp_group(
|
||||
).device_group if self.dcp_size > 1 else None
|
||||
|
||||
def get_num_actual_tokens(self, attn_metadata: M):
|
||||
if self.pcp_size > 1:
|
||||
return attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||
else:
|
||||
return attn_metadata.num_actual_tokens
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
@@ -289,306 +290,23 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return x
|
||||
|
||||
def _compute_prefill_context(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
||||
rope_dim: int,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
):
|
||||
assert len(kv_c_and_k_pe_cache) > 1
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
|
||||
current_seq_len = torch.tensor(prefill_metadata.query_lens,
|
||||
dtype=torch.int32)
|
||||
cache_kv_c = kv_c_and_k_pe_cache[0]
|
||||
cache_k_pe = kv_c_and_k_pe_cache[1]
|
||||
num_heads = cache_k_pe.size(2)
|
||||
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
# chunk_seq_lens will be padded when pcp&dcp
|
||||
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
||||
i]
|
||||
context_seq_len_npu = prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
|
||||
i]
|
||||
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||
kv_c_normed = torch.empty(toks,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
k_pe = torch.empty(toks,
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
context_seq_len_npu,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
|
||||
cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
|
||||
if self.dcp_size > 1:
|
||||
cache_kv_c_k_pe = get_dcp_group().all_gather(
|
||||
cache_kv_c_k_pe, 0)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
cache_kv_c_k_pe = get_pcp_group().all_gather(
|
||||
cache_kv_c_k_pe, 0)
|
||||
|
||||
allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed, k_pe = self._reorg_kvcache(
|
||||
allgatered_kv_c_normed,
|
||||
allgatered_k_pe,
|
||||
padded_local_chunk_seq_lens_lst=prefill_metadata.
|
||||
chunked_context.padded_local_chunk_seq_lens[i],
|
||||
local_context_lens_allranks=prefill_metadata.chunked_context.
|
||||
local_context_lens_allranks,
|
||||
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i]
|
||||
[-1],
|
||||
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
|
||||
chunk_size=prefill_metadata.chunked_context.chunk_size,
|
||||
chunk_idx=i,
|
||||
toks=toks,
|
||||
)
|
||||
|
||||
kv_c_normed = kv_c_normed.squeeze()
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope \
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
|
||||
mask = attn_metadata.attn_mask
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer_name,
|
||||
hidden_states: torch.Tensor, # query in unified attn
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: M,
|
||||
need_gather_q_kv: bool = False,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
return output.fill_(0)
|
||||
|
||||
forward_context = get_forward_context()
|
||||
|
||||
if self.pcp_size > 1:
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||
else:
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache,
|
||||
attn_metadata):
|
||||
if not self.pcp_size > 1:
|
||||
return super().mla_preprocess_prefill(q_c, kv_no_split, kv_cache,
|
||||
attn_metadata)
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
o_proj_input_shape = (forward_context.num_tokens,
|
||||
self.num_heads * self.v_head_dim)
|
||||
o_proj_input = torch.empty(o_proj_input_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
# MLA Preprocess
|
||||
if self.enable_mlapo and not has_prefill:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), need_gather_q_kv)
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
|
||||
hidden_states, kv_cache, attn_metadata)
|
||||
else:
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
|
||||
layer_name, hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv)
|
||||
|
||||
if decode_preprocess_res is not None:
|
||||
# MLA Preprocess for decoding
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
output_decode = self._forward_decode_pcp_dcp(
|
||||
decode_preprocess_res.ql_nope,
|
||||
decode_preprocess_res.q_pe,
|
||||
decode_preprocess_res.k_nope,
|
||||
decode_preprocess_res.k_pe,
|
||||
kv_cache[0].shape[1],
|
||||
attn_metadata,
|
||||
)
|
||||
else:
|
||||
output_decode = self._forward_decode(
|
||||
decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
|
||||
decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
|
||||
kv_cache[0].shape[1], attn_metadata)
|
||||
|
||||
o_proj_input[:num_decode_tokens] = output_decode
|
||||
|
||||
if prefill_preprocess_res is not None:
|
||||
# FIX: aicore move should be also placed on the comm stream in dbo,
|
||||
# otherwise it may affect the accuracy
|
||||
# TODO: use an elegant way to overlap
|
||||
if self.pcp_size > 1:
|
||||
output_prefill = self._forward_prefill_cp(
|
||||
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
|
||||
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
|
||||
prefill_preprocess_res.value, kv_cache, attn_metadata)
|
||||
else:
|
||||
output_prefill = self._forward_prefill(
|
||||
prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
|
||||
prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
|
||||
prefill_preprocess_res.value, kv_cache, attn_metadata)
|
||||
|
||||
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
|
||||
# O proj
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
|
||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||
dependency=o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
|
||||
output[...] = self.o_proj(o_proj_input,
|
||||
is_prefill=(prefill_preprocess_res
|
||||
is not None))[0]
|
||||
|
||||
del o_proj_input
|
||||
|
||||
if has_prefill:
|
||||
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
|
||||
return output_padded
|
||||
|
||||
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
|
||||
attn_metadata, need_gather_q_kv):
|
||||
# MLA Preprocess:
|
||||
# 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
|
||||
# or
|
||||
# Perform kv_a_proj_with_mqa to obtain kv_no_split
|
||||
# 2. If need_gather_q_kv, perform all_gather.
|
||||
# 3. Preprocess decode tokens, write kv cache and get:
|
||||
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
|
||||
# 4. Preprocess prefill tokens, write kv cache and get:
|
||||
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
if self.fused_qkv_a_proj is not None:
|
||||
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
|
||||
dependency=hidden_states,
|
||||
enabled=self.enable_prefetch)
|
||||
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
|
||||
q_c, kv_no_split = qkv_lora.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
dim=-1,
|
||||
)
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
# allgather need contiguous data
|
||||
kv_no_split = kv_no_split.contiguous()
|
||||
else:
|
||||
q_c = hidden_states
|
||||
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
|
||||
# Process for Flash Comm V1
|
||||
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
q_c.contiguous(), need_gather_q_kv)
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split.contiguous(), need_gather_q_kv)
|
||||
|
||||
if self.fc2_o_shared_enable and is_hidden_layer(
|
||||
self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
|
||||
decode_preprocess_res = None
|
||||
prefill_preprocess_res = None
|
||||
if has_prefill:
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
# Preprocess for decode tokens
|
||||
if has_decode:
|
||||
decode_q_c = q_c[:num_decode_tokens]
|
||||
cos = attn_metadata.decode.cos
|
||||
sin = attn_metadata.decode.sin
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_q_c)
|
||||
if self.dcp_size > 1:
|
||||
decode_q_no_split = torch.cat([decode_ql_nope, decode_q_pe],
|
||||
dim=-1)
|
||||
decode_q_no_split = get_dcp_group().all_gather(
|
||||
decode_q_no_split, 1)
|
||||
decode_ql_nope, decode_q_pe = decode_q_no_split.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens *
|
||||
self.pcp_size:self.
|
||||
pcp_size]
|
||||
decode_kv_no_split = kv_no_split[:num_decode_tokens]
|
||||
decode_k_pe, decode_k_nope = self.exec_kv_decode(
|
||||
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
|
||||
decode_preprocess_res = DecodeMLAPreprocessResult(
|
||||
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
||||
# Preprocess for prefill tokens
|
||||
if has_prefill:
|
||||
if self.pcp_size > 1:
|
||||
num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded
|
||||
- self.pcp_size * num_decode_tokens
|
||||
num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded -
|
||||
self.pcp_size * num_decode_tokens
|
||||
) // self.pcp_size + num_decode_tokens
|
||||
prefill_kv_no_split = kv_no_split[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
|
||||
prefill_q = self.q_proj(prefill_q_c)[0] \
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
||||
if self.pcp_size > 1:
|
||||
cos = attn_metadata.prefill.cos[:num_actual_tokens -
|
||||
num_decode_tokens]
|
||||
sin = attn_metadata.prefill.sin[:num_actual_tokens -
|
||||
num_decode_tokens]
|
||||
else:
|
||||
cos = attn_metadata.prefill.cos
|
||||
sin = attn_metadata.prefill.sin
|
||||
prefill_slots = attn_metadata.slot_mapping[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
cos = attn_metadata.prefill.cos[:num_actual_tokens - num_decode_tokens]
|
||||
sin = attn_metadata.prefill.sin[:num_actual_tokens - num_decode_tokens]
|
||||
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||
if self.pcp_size > 1:
|
||||
prefill_kv_no_split = kv_no_split[:num_actual_tokens]
|
||||
kv_c, k_pe = prefill_kv_no_split.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
@@ -600,18 +318,15 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
[num_actual_tokens, self.num_kv_heads, -1])
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
prefill_k_pe = k_pe
|
||||
prefill_k_pe[
|
||||
num_decode_tokens:num_actual_tokens] = self.rope_single(
|
||||
prefill_k_pe[num_decode_tokens:num_actual_tokens], cos,
|
||||
sin)
|
||||
prefill_k_pe[num_decode_tokens:num_actual_tokens] = self.rope_single(
|
||||
prefill_k_pe[num_decode_tokens:num_actual_tokens], cos, sin)
|
||||
prefill_k_c_normed = kv_c_normed[:num_actual_tokens]
|
||||
prefill_kv_c_k_pe = torch.cat(
|
||||
[prefill_k_c_normed, prefill_k_pe], dim=-1)
|
||||
prefill_kv_c_k_pe = get_pcp_group().all_gather(
|
||||
prefill_kv_c_k_pe, 0)
|
||||
prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe],
|
||||
dim=-1)
|
||||
prefill_kv_c_k_pe = get_pcp_group().all_gather(prefill_kv_c_k_pe, 0)
|
||||
prefill_kv_c_k_pe = torch.index_select(
|
||||
prefill_kv_c_k_pe, 0, attn_metadata.prefill.pcp_metadata.
|
||||
pcp_allgather_restore_idx)
|
||||
prefill_kv_c_k_pe, 0,
|
||||
attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx)
|
||||
prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens *
|
||||
self.pcp_size:]
|
||||
prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split(
|
||||
@@ -625,93 +340,57 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
key_cache=kv_cache[0],
|
||||
value_cache=kv_cache[1],
|
||||
slot_indices=slot_mapping)
|
||||
else:
|
||||
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
|
||||
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
|
||||
prefill_k_nope, prefill_value = self.kv_b_proj(
|
||||
prefill_k_c_normed)[0].view(
|
||||
-1, self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim).split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
if not self.pcp_size > 1:
|
||||
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
|
||||
self.num_kv_heads, -1)
|
||||
prefill_k_pe = prefill_k_pe.expand(
|
||||
(*prefill_k_nope.shape[:-1], -1))
|
||||
prefill_preprocess_res = PrefillMLAPreprocessResult(
|
||||
prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
|
||||
prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
|
||||
return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe,
|
||||
prefill_k_nope, prefill_k_pe,
|
||||
prefill_value)
|
||||
return decode_preprocess_res, prefill_preprocess_res
|
||||
|
||||
def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
|
||||
bsz = attn_metadata.num_decode_tokens
|
||||
hidden_states = hidden_states[:bsz]
|
||||
def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata):
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
decode_q_c = q_c[:num_decode_tokens]
|
||||
cos = attn_metadata.decode.cos
|
||||
sin = attn_metadata.decode.sin
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_q_c)
|
||||
decode_ql_nope, decode_q_pe = self.reorg_decode_q(
|
||||
decode_ql_nope, decode_q_pe)
|
||||
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens *
|
||||
self.pcp_size:self.pcp_size]
|
||||
decode_kv_no_split = kv_no_split[:num_decode_tokens]
|
||||
decode_k_pe, decode_k_nope = self.exec_kv_decode(
|
||||
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
|
||||
return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe,
|
||||
decode_k_nope, decode_k_pe)
|
||||
|
||||
cos_shape = attn_metadata.decode.cos.shape
|
||||
cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1])
|
||||
sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
|
||||
|
||||
decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
|
||||
decode_q_nope = torch.empty(
|
||||
(hidden_states.shape[0], self.W_UK_T.shape[0],
|
||||
decode_k_nope.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
decode_q_pe = torch.empty(
|
||||
(hidden_states.shape[0], self.W_UK_T.shape[0],
|
||||
decode_k_pe.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
torch.ops._C_ascend.mla_preprocess(
|
||||
hidden_states,
|
||||
self.wd_qkv,
|
||||
self.deq_scale_qkv,
|
||||
self.gamma1,
|
||||
self.beta1,
|
||||
self.wu_q,
|
||||
self.qb_deq_scl,
|
||||
self.gamma2,
|
||||
cos,
|
||||
sin,
|
||||
self.W_UK_T,
|
||||
decode_k_nope,
|
||||
decode_k_pe,
|
||||
attn_metadata.slot_mapping[:bsz].flatten(),
|
||||
quant_scale0=self.quant_scale0,
|
||||
quant_offset0=self.quant_offset0,
|
||||
bias0=self.quant_bias_qkv,
|
||||
quant_scale1=self.quant_scale1,
|
||||
quant_offset1=self.quant_offset1,
|
||||
bias1=self.qb_qt_bias,
|
||||
ctkv_scale=self.ctkv_scale,
|
||||
q_nope_scale=self.q_nope_scale,
|
||||
cache_mode="krope_ctkv",
|
||||
quant_mode="per_tensor_quant_asymm",
|
||||
q_out0=decode_q_nope,
|
||||
kv_cache_out0=decode_k_nope,
|
||||
q_out1=decode_q_pe,
|
||||
kv_cache_out1=decode_k_pe,
|
||||
enable_inner_out=False,
|
||||
inner_out=torch.tensor([], device=hidden_states.device))
|
||||
decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
|
||||
self.kv_lora_rank)
|
||||
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
||||
def get_context_seq_len_npu(self, index: int,
|
||||
attn_metadata: AscendMLAMetadata):
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
assert prefill_metadata is not None
|
||||
assert prefill_metadata.chunked_context is not None
|
||||
assert isinstance(prefill_metadata.chunked_context,
|
||||
CPChunkedContextMetadata)
|
||||
assert prefill_metadata.chunked_context.padded_chunk_seq_lens_npu is not None
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
assert 0 <= index < iters
|
||||
return prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
|
||||
index]
|
||||
|
||||
def reorg_decode_q(self, decode_q_nope, decode_q_pe):
|
||||
if self.dcp_size > 1:
|
||||
decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1)
|
||||
decode_q_no_split = get_dcp_group().all_gather(
|
||||
decode_q_no_split, 1)
|
||||
decode_q_nope, decode_q_pe = decode_q_no_split.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
return decode_q_nope, decode_q_pe
|
||||
|
||||
decode_preprocess_res = DecodeMLAPreprocessResult(
|
||||
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
||||
return decode_preprocess_res, None
|
||||
|
||||
def _forward_prefill_cp(
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
@@ -721,6 +400,9 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
if not self.pcp_size > 1:
|
||||
return super()._forward_prefill(q_nope, q_pe, k_nope, k_pe, value,
|
||||
kv_c_and_k_pe_cache, attn_metadata)
|
||||
assert attn_metadata.prefill is not None
|
||||
assert attn_metadata.prefill.pcp_metadata is not None
|
||||
num_tokens = q_nope.size(0)
|
||||
@@ -840,7 +522,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
softmax_lse=attn_lse)
|
||||
return attn_output, attn_lse
|
||||
|
||||
def _forward_decode_pcp_dcp(
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
@@ -1014,13 +696,9 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
|
||||
def _reorg_kvcache(
|
||||
self,
|
||||
allgatered_kv_c_normed: torch.Tensor,
|
||||
allgatered_k_pe: torch.Tensor,
|
||||
padded_local_chunk_seq_lens_lst: list[int],
|
||||
local_context_lens_allranks: list[list[int]],
|
||||
sum_seq_len: int,
|
||||
max_seq_len: int,
|
||||
chunk_size: int,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
chunked_context: CPChunkedContextMetadata,
|
||||
chunk_idx: int,
|
||||
toks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -1044,6 +722,29 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
chunk_idx: chunk idx of chunked_prefill.
|
||||
toks: the number of tokens for local gather cache.
|
||||
"""
|
||||
assert chunked_context is not None
|
||||
assert chunked_context.padded_local_chunk_seq_lens is not None
|
||||
assert chunked_context.local_context_lens_allranks is not None
|
||||
assert chunked_context.cu_seq_lens_lst is not None
|
||||
assert chunked_context.max_seq_lens is not None
|
||||
assert chunked_context.chunk_size is not None
|
||||
|
||||
padded_local_chunk_seq_lens_lst = chunked_context.padded_local_chunk_seq_lens[
|
||||
chunk_idx]
|
||||
local_context_lens_allranks = chunked_context.local_context_lens_allranks
|
||||
sum_seq_len = chunked_context.cu_seq_lens_lst[chunk_idx][-1]
|
||||
max_seq_len = chunked_context.max_seq_lens[chunk_idx]
|
||||
chunk_size: int = chunked_context.chunk_size
|
||||
cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
|
||||
if self.dcp_size > 1:
|
||||
cache_kv_c_k_pe = get_dcp_group().all_gather(cache_kv_c_k_pe, 0)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
cache_kv_c_k_pe = get_pcp_group().all_gather(cache_kv_c_k_pe, 0)
|
||||
|
||||
allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
kv_c_segments = []
|
||||
k_pe_segments = []
|
||||
src_token_idx = 0
|
||||
|
||||
@@ -503,8 +503,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
common_attn_metadata.block_table_tensor[:common_attn_metadata.
|
||||
num_reqs])
|
||||
|
||||
def set_decode_block_table(
|
||||
self, common_attn_metadata: AscendCommonAttentionMetadata):
|
||||
def set_decode_block_table(self):
|
||||
self.block_table = self.block_table[:self.num_decodes, ...]
|
||||
|
||||
def build_prefill_metadata(
|
||||
@@ -564,7 +563,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
self.seq_lens = self.seq_lens[:self.num_decodes]
|
||||
input_positions = input_positions[:self.num_decode_tokens]
|
||||
|
||||
self.set_decode_block_table(common_attn_metadata)
|
||||
self.set_decode_block_table()
|
||||
|
||||
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
|
||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
||||
@@ -895,6 +894,26 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
|
||||
def get_context_seq_len_npu(self, index: int,
|
||||
attn_metadata: AscendMLAMetadata):
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
assert prefill_metadata is not None
|
||||
assert prefill_metadata.chunked_context is not None
|
||||
assert prefill_metadata.chunked_context.chunk_seq_lens_npu is not None
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
assert 0 <= index < iters
|
||||
return prefill_metadata.chunked_context.chunk_seq_lens_npu[index]
|
||||
|
||||
def _reorg_kvcache(
|
||||
self,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
chunked_context: CPChunkedContextMetadata,
|
||||
chunk_idx: int,
|
||||
toks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return kv_c_normed, k_pe
|
||||
|
||||
def _compute_prefill_context(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
@@ -923,9 +942,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# chunk_seq_lens will be padded when pcp&dcp
|
||||
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
||||
i]
|
||||
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
|
||||
i]
|
||||
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||
context_seq_len_npu = self.get_context_seq_len_npu(
|
||||
i, attn_metadata)
|
||||
kv_c_normed = torch.empty(toks,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
@@ -946,7 +965,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
|
||||
kv_c_normed, k_pe = self._reorg_kvcache(
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
chunked_context=prefill_metadata.chunked_context,
|
||||
chunk_idx=i,
|
||||
toks=toks,
|
||||
)
|
||||
kv_c_normed = kv_c_normed.squeeze()
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
@@ -1210,7 +1235,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
|
||||
def reorg_decode_q(self, decode_q_nope, decode_q_pe):
|
||||
return decode_q_nope, decode_q_pe
|
||||
|
||||
def _mla_preprocess_only_decode(self, hidden_states, kv_cache,
|
||||
attn_metadata):
|
||||
bsz = attn_metadata.num_decode_tokens
|
||||
hidden_states = hidden_states[:bsz]
|
||||
|
||||
@@ -1267,10 +1296,57 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.kv_lora_rank)
|
||||
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
||||
|
||||
decode_q_nope, decode_q_pe = self.reorg_decode_q(
|
||||
decode_q_nope, decode_q_pe)
|
||||
|
||||
decode_preprocess_res = DecodeMLAPreprocessResult(
|
||||
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
||||
return decode_preprocess_res, None
|
||||
|
||||
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache,
|
||||
attn_metadata):
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
prefill_kv_no_split = kv_no_split[num_decode_tokens:num_actual_tokens]
|
||||
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
|
||||
prefill_q = self.q_proj(prefill_q_c)[0] \
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
||||
cos = attn_metadata.prefill.cos
|
||||
sin = attn_metadata.prefill.sin
|
||||
prefill_slots = attn_metadata.slot_mapping[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
|
||||
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
|
||||
prefill_k_nope, prefill_value = self.kv_b_proj(
|
||||
prefill_k_c_normed)[0].view(
|
||||
-1, self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim).split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
|
||||
self.num_kv_heads, -1)
|
||||
prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
|
||||
return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe,
|
||||
prefill_k_nope, prefill_k_pe,
|
||||
prefill_value)
|
||||
|
||||
def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata):
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
decode_q_c = q_c[:num_decode_tokens]
|
||||
cos = attn_metadata.decode.cos
|
||||
sin = attn_metadata.decode.sin
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_q_c)
|
||||
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1]
|
||||
decode_kv_no_split = kv_no_split[:num_decode_tokens]
|
||||
decode_k_pe, decode_k_nope = self.exec_kv_decode(
|
||||
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
|
||||
return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe,
|
||||
decode_k_nope, decode_k_pe)
|
||||
|
||||
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
|
||||
attn_metadata, need_gather_q_kv):
|
||||
# MLA Preprocess:
|
||||
@@ -1284,8 +1360,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
if self.fused_qkv_a_proj is not None:
|
||||
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
|
||||
dependency=hidden_states,
|
||||
@@ -1318,48 +1392,17 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
wait_for_kv_layer_from_connector(layer_name)
|
||||
# Preprocess for decode tokens
|
||||
if has_decode:
|
||||
decode_q_c = q_c[:num_decode_tokens]
|
||||
cos = attn_metadata.decode.cos
|
||||
sin = attn_metadata.decode.sin
|
||||
decode_ql_nope, decode_q_pe = \
|
||||
self._q_proj_and_k_up_proj(decode_q_c)
|
||||
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
||||
decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1]
|
||||
decode_kv_no_split = kv_no_split[:num_decode_tokens]
|
||||
decode_k_pe, decode_k_nope = self.exec_kv_decode(
|
||||
decode_kv_no_split, cos, sin, kv_cache, decode_slots)
|
||||
decode_preprocess_res = DecodeMLAPreprocessResult(
|
||||
decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
||||
decode_preprocess_res = self.mla_preprocess_decode(
|
||||
q_c, kv_no_split, kv_cache, attn_metadata)
|
||||
# Preprocess for prefill tokens
|
||||
if has_prefill:
|
||||
prefill_kv_no_split = kv_no_split[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
|
||||
prefill_q = self.q_proj(prefill_q_c)[0] \
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
||||
cos = attn_metadata.prefill.cos
|
||||
sin = attn_metadata.prefill.sin
|
||||
prefill_slots = attn_metadata.slot_mapping[
|
||||
num_decode_tokens:num_actual_tokens]
|
||||
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
|
||||
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
|
||||
prefill_k_nope, prefill_value = self.kv_b_proj(
|
||||
prefill_k_c_normed)[0].view(
|
||||
-1, self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim).split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
|
||||
self.num_kv_heads, -1)
|
||||
prefill_k_pe = prefill_k_pe.expand(
|
||||
(*prefill_k_nope.shape[:-1], -1))
|
||||
prefill_preprocess_res = PrefillMLAPreprocessResult(
|
||||
prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
|
||||
prefill_value)
|
||||
prefill_preprocess_res = self.mla_preprocess_prefill(
|
||||
q_c, kv_no_split, kv_cache, attn_metadata)
|
||||
return decode_preprocess_res, prefill_preprocess_res
|
||||
|
||||
def get_num_actual_tokens(self, attn_metadata: M):
|
||||
return attn_metadata.num_actual_tokens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer_name,
|
||||
@@ -1378,7 +1421,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
return output.fill_(0)
|
||||
|
||||
forward_context = get_forward_context()
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
num_actual_tokens = self.get_num_actual_tokens(attn_metadata)
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
@@ -1397,13 +1440,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if self.enable_mlapo and not has_prefill:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), need_gather_q_kv)
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(
|
||||
hidden_states, kv_cache, attn_metadata)
|
||||
else:
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
|
||||
layer_name, hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv)
|
||||
|
||||
if decode_preprocess_res is not None:
|
||||
# MLA Preprocess for decoding
|
||||
output_decode = self._forward_decode(decode_preprocess_res.ql_nope,
|
||||
|
||||
Reference in New Issue
Block a user