From 23169021d9f5d20649061ab8f9ab734d7d00dd1b Mon Sep 17 00:00:00 2001
From: wujinyuan1
Date: Sun, 28 Dec 2025 10:40:45 +0800
Subject: [PATCH] [Refactor] 6/N Extract common code of class AscendMLAImpl
 (#5314)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason: Eliminate duplicate code in the IMPL classes across the two files
(mla_v1.py and mla_cp.py).

vLLM version: 0.13.0rc3
vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

- vLLM version: release/v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/5fbfa8d9ef15948599631baeb91e8220b2ee9bcc

---------

Signed-off-by: wujinyuan1
Co-authored-by: wujinyuan1
Co-authored-by: weijinqian0 <1184188277@qq.com>
---
 tests/ut/attention/test_mla_cp.py |  79 +++--
 vllm_ascend/attention/mla_cp.py   | 561 +++++++-----------------------
 vllm_ascend/attention/mla_v1.py   | 142 +++++---
 3 files changed, 268 insertions(+), 514 deletions(-)

diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py
index a7597af8..a40662d6 100755
--- a/tests/ut/attention/test_mla_cp.py
+++ b/tests/ut/attention/test_mla_cp.py
@@ -254,7 +254,7 @@ class TestAscendMLAImpl(TestBase):
 
     @patch('vllm_ascend.attention.mla_cp.get_dcp_group')
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
-    @patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
+    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
     def test_mla_preprocess_dcp(self, magic_npu_fetch,
                                 mock_maybe_all_gather_and_maybe_unpad,
                                 mock_get_dcp_group):
@@ -339,7 +339,7 @@ class TestAscendMLAImpl(TestBase):
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('vllm_ascend.attention.mla_cp.get_pcp_group')
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
-    @patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
+    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
     def test_mla_preprocess_pcp(self, magic_npu_fetch,
                                 mock_maybe_all_gather_and_maybe_unpad,
                                 mock_get_pcp_group,
@@ -543,8 +543,8 @@ class TestAscendMLAImpl(TestBase):
         self.impl._v_up_proj.return_value = torch.randn(
            B, self.impl.v_head_dim)
 
-        result = self.impl._forward_decode_pcp_dcp(q_nope, q_pe, k_nope, k_pe,
-                                                   BS, attn_metadata)
+        result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
+                                           attn_metadata)
 
         self.assertEqual(result.shape[0], B)
         self.assertEqual(result.shape[1], self.impl.v_head_dim)
@@ -578,14 +578,14 @@ class TestAscendMLAImpl(TestBase):
 
         def mock_reorg_kvcache(allgatered_kv_c_normed: torch.Tensor,
                                allgatered_k_pe: torch.Tensor,
-                               padded_local_chunk_seq_lens_lst: list[int],
-                               local_context_lens_allranks: list[list[int]],
-                               sum_seq_len: int, max_seq_len: int,
-                               chunk_size: int, chunk_idx: int, toks: int):
-            return torch.randn(sum_seq_len, allgatered_kv_c_normed.shape[1],
-                               allgatered_kv_c_normed.shape[2]), torch.randn(
-                                   sum_seq_len, allgatered_k_pe.shape[1],
-                                   allgatered_k_pe.shape[2])
+                               chunked_context: CPChunkedContextMetadata,
+                               chunk_idx: int, toks: int):
+            return torch.randn(
+                chunked_context.cu_seq_lens_lst[chunk_idx][-1],
+                allgatered_kv_c_normed.shape[1],
+                allgatered_kv_c_normed.shape[2]), torch.randn(
+                    chunked_context.cu_seq_lens_lst[chunk_idx][-1],
+                    allgatered_k_pe.shape[1], allgatered_k_pe.shape[2])
 
         # mock proj
         self.impl.kv_b_proj.side_effect = mock_kv_b_proj
@@ -679,10 +679,6 @@ class TestAscendMLAImpl(TestBase):
                              iters * (1 if dcp_size * pcp_size > 1 else 0))
             self.assertEqual(mock_load.call_count, iters)
             self.assertEqual(mock_ring.call_count, iters)
-            self.assertEqual(mock_dcp.all_gather.call_count,
-                             (1 if dcp_size > 1 else 0))
-            self.assertEqual(mock_pcp.all_gather.call_count,
-                             iters * (1 if pcp_size > 1 else 0))
             mock_reorg.reset_mock()
             mock_load.reset_mock()
             mock_ring.reset_mock()
@@ -691,7 +687,18 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(out.shape, prefix_out.shape)
         self.assertEqual(lse.shape, prefix_lse.shape)
 
-    def test_reorg_kvcache_with_dcp_pcp(self):
+    @patch('vllm.distributed.parallel_state.get_pcp_group')
+    @patch('vllm.distributed.parallel_state._PCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    def test_reorg_kvcache_with_dcp_pcp(self, mock_dcp, mock_get_dcp_group,
+                                        mock_pcp, mock_get_pcp_group):
+
+        def mock_all_gather(ws):
+            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
+
         BLOCK_SIZE = 128  # fixed
         max_model_len = 4096
         max_num_seqs = 25
@@ -706,6 +713,12 @@ class TestAscendMLAImpl(TestBase):
             pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
             if pcp_size * dcp_size == 1:
                 continue
+            self.impl.dcp_size = dcp_size
+            self.impl.pcp_size = pcp_size
+            mock_dcp.all_gather = MagicMock(
+                side_effect=mock_all_gather(dcp_size))
+            mock_pcp.all_gather = MagicMock(
+                side_effect=mock_all_gather(pcp_size))
             chunked_prefill_workspace_size = min(
                 max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE),
                 128 * 1024)
@@ -723,27 +736,21 @@ class TestAscendMLAImpl(TestBase):
             for i in range(len(chunked_context.seq_tot)):
                 allgatered_kv_c_normed = torch.randn(
-                    chunked_context.seq_tot[i] * pcp_size * dcp_size,
-                    self.impl.num_heads, self.impl.v_head_dim)
-                allgatered_k_pe = torch.randn(
-                    chunked_context.seq_tot[i] * pcp_size * dcp_size,
-                    self.impl.num_heads, self.impl.qk_rope_head_dim)
+                    chunked_context.seq_tot[i], self.impl.num_heads,
+                    self.impl.kv_lora_rank)
+                allgatered_k_pe = torch.randn(chunked_context.seq_tot[i],
+                                              self.impl.num_heads,
+                                              self.impl.qk_rope_head_dim)
                 result_kv, result_k_pe = self.impl._reorg_kvcache(
                     allgatered_kv_c_normed,
                     allgatered_k_pe,
-                    padded_local_chunk_seq_lens_lst=chunked_context.
-                    padded_local_chunk_seq_lens[i],
-                    local_context_lens_allranks=chunked_context.
-                    local_context_lens_allranks,
-                    sum_seq_len=chunked_context.cu_seq_lens_lst[i][-1],
-                    max_seq_len=chunked_context.max_seq_lens[i],
-                    chunk_size=chunked_context.chunk_size,
+                    chunked_context,
                     chunk_idx=i,
                     toks=chunked_context.seq_tot[i],
                 )
                 self.assertEqual(
                     result_kv.shape,
                     (chunked_context.cu_seq_lens_lst[i][-1],
-                     self.impl.num_heads, self.impl.v_head_dim))
+                     self.impl.num_heads, self.impl.kv_lora_rank))
                 self.assertEqual(
                     result_k_pe.shape,
                     (chunked_context.cu_seq_lens_lst[i][-1],
@@ -754,6 +761,11 @@ class TestAscendMLAImpl(TestBase):
                 self.assertEqual(result_k_pe.shape[0],
                                  chunked_context.cu_seq_lens_lst[i][-1])
 
+            self.assertEqual(mock_dcp.all_gather.call_count,
+                             (1 if dcp_size > 1 else 0))
+            self.assertEqual(mock_pcp.all_gather.call_count,
+                             (1 if pcp_size > 1 else 0))
+
     def test_out_lse_reshape(self):
         test_cases = [10, 1, 128, 512]
         for test_case in test_cases:
@@ -1052,10 +1064,9 @@ class TestAscendMLAImpl(TestBase):
             attn_metadata.prefill.pcp_metadata.pcp_prefill_mask = torch.triu(
                 torch.ones(10, 10, dtype=torch.float16), 1)
-            output = self.impl._forward_prefill_cp(q_nope, q_pe, k_nope,
-                                                   k_pe, value,
-                                                   kv_c_and_k_pe_cache,
-                                                   attn_metadata)
+            output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
+                                                value, kv_c_and_k_pe_cache,
+                                                attn_metadata)
 
             self.assertEqual(
                 output.shape,
                 (seq_len_q, self.impl.num_heads * self.impl.v_head_dim))
diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py
index 4ce90cb1..30a6c594 100644
--- a/vllm_ascend/attention/mla_cp.py
+++ b/vllm_ascend/attention/mla_cp.py
@@ -23,16 +23,12 @@ from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
                                           PrefillMLAPreprocessResult)
 #isort: on
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         maybe_save_kv_layer_to_connector,
-                                         wait_for_kv_layer_from_connector)
-from vllm_ascend.attention.common_cp import AscendPCPMetadata, CPChunkedContextMetadata
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata)
+from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
+                                             CPChunkedContextMetadata)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                get_mtp_graph_params,
                                                update_graph_params_workspaces)
-from vllm_ascend.ops.shared_weight_layer import (
-    is_hidden_layer, reach_layer_for_shared_weight_series)
-from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.utils import weak_ref_tensors
 
 MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
@@ -197,8 +193,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                                                                + self.
                                                                num_prefills]
 
-    def set_decode_block_table(
-            self, common_attn_metadata: AscendCommonAttentionMetadata):
+    def set_decode_block_table(self):
         self.block_table = self.block_table[:self.num_decodes_flatten, ...]
 
     def build_prefill_metadata(
@@ -280,6 +275,12 @@ class AscendMlaCPImpl(AscendMLAImpl):
         self.dcp_group = get_dcp_group(
         ).device_group if self.dcp_size > 1 else None
 
+    def get_num_actual_tokens(self, attn_metadata: M):
+        if self.pcp_size > 1:
+            return attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
+        else:
+            return attn_metadata.num_actual_tokens
+
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -289,429 +290,107 @@ class AscendMlaCPImpl(AscendMLAImpl):
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
         return x
 
-    def _compute_prefill_context(
-        self,
-        q_nope: torch.Tensor,
-        q_pe: torch.Tensor,
-        kv_c_and_k_pe_cache: Tuple[torch.Tensor],
-        rope_dim: int,
-        attn_metadata: AscendMLAMetadata,
-        prefix_output: torch.Tensor,
-        prefix_lse: torch.Tensor,
-    ):
-        assert len(kv_c_and_k_pe_cache) > 1
+    def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache,
+                               attn_metadata):
+        if not self.pcp_size > 1:
+            return super().mla_preprocess_prefill(q_c, kv_no_split, kv_cache,
+                                                  attn_metadata)
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded -
+                             self.pcp_size * num_decode_tokens
+                             ) // self.pcp_size + num_decode_tokens
+        prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
+        prefill_q = self.q_proj(prefill_q_c)[0] \
+            .view(-1, self.num_heads, self.qk_head_dim)
+        prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+        prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
+        cos = attn_metadata.prefill.cos[:num_actual_tokens - num_decode_tokens]
+        sin = attn_metadata.prefill.sin[:num_actual_tokens - num_decode_tokens]
+        prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+        prefill_kv_no_split = kv_no_split[:num_actual_tokens]
+        kv_c, k_pe = prefill_kv_no_split.split(
+            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+        assert len(
+            kv_cache
+        ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
+        kv_c_normed = kv_c_normed.view(
+            [num_actual_tokens, self.num_kv_heads, -1])
+        k_pe = k_pe.unsqueeze(1)
+        prefill_k_pe = k_pe
+        prefill_k_pe[num_decode_tokens:num_actual_tokens] = self.rope_single(
+            prefill_k_pe[num_decode_tokens:num_actual_tokens], cos, sin)
+        prefill_k_c_normed = kv_c_normed[:num_actual_tokens]
+        prefill_kv_c_k_pe = torch.cat([prefill_k_c_normed, prefill_k_pe],
+                                      dim=-1)
+        prefill_kv_c_k_pe = get_pcp_group().all_gather(prefill_kv_c_k_pe, 0)
+        prefill_kv_c_k_pe = torch.index_select(
+            prefill_kv_c_k_pe, 0,
+            attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx)
+        prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens *
+                                              self.pcp_size:]
+        prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split(
+            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
+        prefill_k_c_normed = prefill_k_c_normed.squeeze()
+        slot_mapping = attn_metadata.slot_mapping[self.pcp_size *
+                                                  num_decode_tokens:]
+        torch_npu._npu_reshape_and_cache(key=kv_c_normed,
+                                         value=k_pe,
+                                         key_cache=kv_cache[0],
+                                         value_cache=kv_cache[1],
+                                         slot_indices=slot_mapping)
+        prefill_k_nope, prefill_value = self.kv_b_proj(
+            prefill_k_c_normed)[0].view(
+                -1, self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim).split(
+                    [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+        prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
+        return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe,
+                                          prefill_k_nope, prefill_k_pe,
+                                          prefill_value)
+
+    def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata):
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        decode_q_c = q_c[:num_decode_tokens]
+        cos = attn_metadata.decode.cos
+        sin = attn_metadata.decode.sin
+        decode_ql_nope, decode_q_pe = \
+            self._q_proj_and_k_up_proj(decode_q_c)
+        decode_ql_nope, decode_q_pe = self.reorg_decode_q(
+            decode_ql_nope, decode_q_pe)
+        decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+        decode_slots = attn_metadata.slot_mapping[:num_decode_tokens *
+                                                  self.pcp_size:self.pcp_size]
+        decode_kv_no_split = kv_no_split[:num_decode_tokens]
+        decode_k_pe, decode_k_nope = self.exec_kv_decode(
+            decode_kv_no_split, cos, sin, kv_cache, decode_slots)
+        return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe,
+                                         decode_k_nope, decode_k_pe)
+
+    def get_context_seq_len_npu(self, index: int,
+                                attn_metadata: AscendMLAMetadata):
         prefill_metadata = attn_metadata.prefill
-        if prefill_metadata is None or prefill_metadata.chunked_context is None:
-            return prefix_output, prefix_lse
-
+        assert prefill_metadata is not None
+        assert prefill_metadata.chunked_context is not None
+        assert isinstance(prefill_metadata.chunked_context,
+                          CPChunkedContextMetadata)
+        assert prefill_metadata.chunked_context.padded_chunk_seq_lens_npu is not None
         iters = len(prefill_metadata.chunked_context.seq_tot)
+        assert 0 <= index < iters
+        return prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
+            index]
 
-        current_seq_len = torch.tensor(prefill_metadata.query_lens,
-                                       dtype=torch.int32)
-        cache_kv_c = kv_c_and_k_pe_cache[0]
-        cache_k_pe = kv_c_and_k_pe_cache[1]
-        num_heads = cache_k_pe.size(2)
-        latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
-        for i in range(iters):
-            toks = prefill_metadata.chunked_context.seq_tot[i]
-            # chunk_seq_lens will be padded when pcp&dcp
-            context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
-                i]
-            context_seq_len_npu = prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
-                i]
-            seq_len = torch.stack([current_seq_len, context_seq_len])
-            kv_c_normed = torch.empty(toks,
-                                      num_heads,
-                                      latent_kv_dim,
-                                      dtype=q_nope.dtype,
-                                      device=q_nope.device)
-            k_pe = torch.empty(toks,
-                               num_heads,
-                               rope_dim,
-                               dtype=q_nope.dtype,
-                               device=q_nope.device)
-
-            torch_npu.atb.npu_paged_cache_load(
-                cache_kv_c,
-                cache_k_pe,
-                prefill_metadata.block_table,
-                context_seq_len_npu,
-                seq_starts=prefill_metadata.chunked_context.starts[i],
-                key=kv_c_normed,
-                value=k_pe,
-            )
-
-            cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
-            if self.dcp_size > 1:
-                cache_kv_c_k_pe = get_dcp_group().all_gather(
-                    cache_kv_c_k_pe, 0)
-
-            if self.pcp_size > 1:
-                cache_kv_c_k_pe = get_pcp_group().all_gather(
-                    cache_kv_c_k_pe, 0)
-
-            allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
-                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed, k_pe = self._reorg_kvcache(
-                allgatered_kv_c_normed,
-                allgatered_k_pe,
-                padded_local_chunk_seq_lens_lst=prefill_metadata.
-                chunked_context.padded_local_chunk_seq_lens[i],
-                local_context_lens_allranks=prefill_metadata.chunked_context.
-                local_context_lens_allranks,
-                sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i]
-                [-1],
-                max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
-                chunk_size=prefill_metadata.chunked_context.chunk_size,
-                chunk_idx=i,
-                toks=toks,
-            )
-
-            kv_c_normed = kv_c_normed.squeeze()
-            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = kv_nope \
-                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
-
-            mask = attn_metadata.attn_mask
-            torch_npu.atb.npu_ring_mla(
-                q_nope=q_nope,
-                q_rope=q_pe,
-                k_nope=k_nope,
-                k_rope=k_pe,
-                value=v,
-                mask=mask,
-                seqlen=seq_len,
-                head_num=self.num_heads,
-                kv_head_num=self.num_heads,
-                pre_out=prefix_output,
-                prev_lse=prefix_lse,
-                qk_scale=self.scale,
-                kernel_type="kernel_type_high_precision",
-                mask_type="no_mask",
-                input_layout="type_bsnd",
-                calc_type="calc_type_default",
-                output=prefix_output,
-                softmax_lse=prefix_lse)
-        return prefix_output, prefix_lse
-
-    def forward(
-        self,
-        layer_name,
-        hidden_states: torch.Tensor,  # query in unified attn
-        kv_cache: Tuple[torch.Tensor],
-        attn_metadata: M,
-        need_gather_q_kv: bool = False,
-        output: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        assert output is not None, "Output tensor must be provided."
-        if attn_metadata is None:
-            # Profiling run.
-            if self.fc2_o_shared_enable and is_hidden_layer(
-                    self.vllm_config, self.o_proj):
-                reach_layer_for_shared_weight_series(self.o_proj)
-            return output.fill_(0)
-
-        forward_context = get_forward_context()
-
-        if self.pcp_size > 1:
-            num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
-        else:
-            num_actual_tokens = attn_metadata.num_actual_tokens
-        assert attn_metadata.num_decodes is not None and \
-            attn_metadata.num_prefills is not None and \
-            attn_metadata.num_decode_tokens is not None
-
-        has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        # Inputs and outputs may be padded for CUDA graphs
-        output_padded = output
-        o_proj_input_shape = (forward_context.num_tokens,
-                              self.num_heads * self.v_head_dim)
-        o_proj_input = torch.empty(o_proj_input_shape,
-                                   dtype=hidden_states.dtype,
-                                   device=hidden_states.device)
-
-        # MLA Preprocess
-        if self.enable_mlapo and not has_prefill:
-            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                hidden_states.contiguous(), need_gather_q_kv)
-            decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
-                hidden_states, kv_cache, attn_metadata)
-        else:
-            decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
-                layer_name, hidden_states, kv_cache, attn_metadata,
-                need_gather_q_kv)
-
-        if decode_preprocess_res is not None:
-            # MLA Preprocess for decoding
-            if self.pcp_size * self.dcp_size > 1:
-                output_decode = self._forward_decode_pcp_dcp(
-                    decode_preprocess_res.ql_nope,
-                    decode_preprocess_res.q_pe,
-                    decode_preprocess_res.k_nope,
-                    decode_preprocess_res.k_pe,
-                    kv_cache[0].shape[1],
-                    attn_metadata,
-                )
-            else:
-                output_decode = self._forward_decode(
-                    decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe,
-                    decode_preprocess_res.k_nope, decode_preprocess_res.k_pe,
-                    kv_cache[0].shape[1], attn_metadata)
-
-            o_proj_input[:num_decode_tokens] = output_decode
-
-        if prefill_preprocess_res is not None:
-            # FIX: aicore move should be also placed on the comm stream in dbo,
-            # otherwise it may affect the accuracy
-            # TODO: use an elegant way to overlap
-            if self.pcp_size > 1:
-                output_prefill = self._forward_prefill_cp(
-                    prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
-                    prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
-                    prefill_preprocess_res.value, kv_cache, attn_metadata)
-            else:
-                output_prefill = self._forward_prefill(
-                    prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe,
-                    prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe,
-                    prefill_preprocess_res.value, kv_cache, attn_metadata)
-
-            o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
-        # O proj
-        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
-        maybe_npu_prefetch(inputs=self.o_proj.weight,
-                           dependency=o_proj_input,
-                           max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                           enabled=self.enable_prefetch)
-
-        output[...] = self.o_proj(o_proj_input,
-                                  is_prefill=(prefill_preprocess_res
-                                              is not None))[0]
-
-        del o_proj_input
-
-        if has_prefill:
-            maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
-        return output_padded
-
-    def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
-                        attn_metadata, need_gather_q_kv):
-        # MLA Preprocess:
-        # 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
-        #    or
-        #    Perform kv_a_proj_with_mqa to obtain kv_no_split
-        # 2. If need_gather_q_kv, perform all_gather.
-        # 3. Preprocess decode tokens, write kv cache and get:
-        #    decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
-        # 4. Preprocess prefill tokens, write kv cache and get:
-        #    prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
-        has_decode = attn_metadata.num_decodes > 0
-        has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        num_actual_tokens = attn_metadata.num_actual_tokens
-        if self.fused_qkv_a_proj is not None:
-            maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
-                               dependency=hidden_states,
-                               enabled=self.enable_prefetch)
-            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
-            q_c, kv_no_split = qkv_lora.split(
-                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-                dim=-1,
-            )
-            q_c = self.q_a_layernorm(q_c)
-            # allgather need contiguous data
-            kv_no_split = kv_no_split.contiguous()
-        else:
-            q_c = hidden_states
-            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
-
-        # Process for Flash Comm V1
-        q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            q_c.contiguous(), need_gather_q_kv)
-        kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            kv_no_split.contiguous(), need_gather_q_kv)
-
-        if self.fc2_o_shared_enable and is_hidden_layer(
-                self.vllm_config, self.o_proj):
-            reach_layer_for_shared_weight_series(self.o_proj)
-
-        decode_preprocess_res = None
-        prefill_preprocess_res = None
-        if has_prefill:
-            wait_for_kv_layer_from_connector(layer_name)
-        # Preprocess for decode tokens
-        if has_decode:
-            decode_q_c = q_c[:num_decode_tokens]
-            cos = attn_metadata.decode.cos
-            sin = attn_metadata.decode.sin
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_q_c)
-            if self.dcp_size > 1:
-                decode_q_no_split = torch.cat([decode_ql_nope, decode_q_pe],
-                                              dim=-1)
-                decode_q_no_split = get_dcp_group().all_gather(
-                    decode_q_no_split, 1)
-                decode_ql_nope, decode_q_pe = decode_q_no_split.split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
-            decode_slots = attn_metadata.slot_mapping[:num_decode_tokens *
-                                                      self.pcp_size:self.
-                                                      pcp_size]
-            decode_kv_no_split = kv_no_split[:num_decode_tokens]
-            decode_k_pe, decode_k_nope = self.exec_kv_decode(
-                decode_kv_no_split, cos, sin, kv_cache, decode_slots)
-            decode_preprocess_res = DecodeMLAPreprocessResult(
-                decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
-        # Preprocess for prefill tokens
-        if has_prefill:
-            if self.pcp_size > 1:
-                num_actual_tokens = (attn_metadata.num_actual_tokens_pcp_padded
-                                     - self.pcp_size * num_decode_tokens
-                                     ) // self.pcp_size + num_decode_tokens
-            prefill_kv_no_split = kv_no_split[
-                num_decode_tokens:num_actual_tokens]
-            prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
-            prefill_q = self.q_proj(prefill_q_c)[0] \
-                .view(-1, self.num_heads, self.qk_head_dim)
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
-            if self.pcp_size > 1:
-                cos = attn_metadata.prefill.cos[:num_actual_tokens -
-                                                num_decode_tokens]
-                sin = attn_metadata.prefill.sin[:num_actual_tokens -
-                                                num_decode_tokens]
-            else:
-                cos = attn_metadata.prefill.cos
-                sin = attn_metadata.prefill.sin
-            prefill_slots = attn_metadata.slot_mapping[
-                num_decode_tokens:num_actual_tokens]
-            prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
-            if self.pcp_size > 1:
-                prefill_kv_no_split = kv_no_split[:num_actual_tokens]
-                kv_c, k_pe = prefill_kv_no_split.split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-                assert len(
-                    kv_cache
-                ) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
-                kv_c_normed = kv_c_normed.view(
-                    [num_actual_tokens, self.num_kv_heads, -1])
-                k_pe = k_pe.unsqueeze(1)
-                prefill_k_pe = k_pe
-                prefill_k_pe[
-                    num_decode_tokens:num_actual_tokens] = self.rope_single(
-                        prefill_k_pe[num_decode_tokens:num_actual_tokens], cos,
-                        sin)
-                prefill_k_c_normed = kv_c_normed[:num_actual_tokens]
-                prefill_kv_c_k_pe = torch.cat(
-                    [prefill_k_c_normed, prefill_k_pe], dim=-1)
-                prefill_kv_c_k_pe = get_pcp_group().all_gather(
-                    prefill_kv_c_k_pe, 0)
-                prefill_kv_c_k_pe = torch.index_select(
-                    prefill_kv_c_k_pe, 0, attn_metadata.prefill.pcp_metadata.
-                    pcp_allgather_restore_idx)
-                prefill_kv_c_k_pe = prefill_kv_c_k_pe[num_decode_tokens *
-                                                      self.pcp_size:]
-                prefill_k_c_normed, prefill_k_pe = prefill_kv_c_k_pe.split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-                kv_c_normed, k_pe = prefill_k_c_normed, prefill_k_pe
-                prefill_k_c_normed = prefill_k_c_normed.squeeze()
-                slot_mapping = attn_metadata.slot_mapping[self.pcp_size *
-                                                          num_decode_tokens:]
-                torch_npu._npu_reshape_and_cache(key=kv_c_normed,
-                                                 value=k_pe,
-                                                 key_cache=kv_cache[0],
-                                                 value_cache=kv_cache[1],
-                                                 slot_indices=slot_mapping)
-            else:
-                prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
-                    prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
-            prefill_k_nope, prefill_value = self.kv_b_proj(
-                prefill_k_c_normed)[0].view(
-                    -1, self.num_heads,
-                    self.qk_nope_head_dim + self.v_head_dim).split(
-                        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            if not self.pcp_size > 1:
-                prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
-                                                 self.num_kv_heads, -1)
-            prefill_k_pe = prefill_k_pe.expand(
-                (*prefill_k_nope.shape[:-1], -1))
-            prefill_preprocess_res = PrefillMLAPreprocessResult(
-                prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
-                prefill_value)
-        return decode_preprocess_res, prefill_preprocess_res
-
-    def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
-        bsz = attn_metadata.num_decode_tokens
-        hidden_states = hidden_states[:bsz]
-
-        cos_shape = attn_metadata.decode.cos.shape
-        cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1])
-        sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
-
-        decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
-        decode_q_nope = torch.empty(
-            (hidden_states.shape[0], self.W_UK_T.shape[0],
-             decode_k_nope.shape[-1]),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-        decode_q_pe = torch.empty(
-            (hidden_states.shape[0], self.W_UK_T.shape[0],
-             decode_k_pe.shape[-1]),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        torch.ops._C_ascend.mla_preprocess(
-            hidden_states,
-            self.wd_qkv,
-            self.deq_scale_qkv,
-            self.gamma1,
-            self.beta1,
-            self.wu_q,
-            self.qb_deq_scl,
-            self.gamma2,
-            cos,
-            sin,
-            self.W_UK_T,
-            decode_k_nope,
-            decode_k_pe,
-            attn_metadata.slot_mapping[:bsz].flatten(),
-            quant_scale0=self.quant_scale0,
-            quant_offset0=self.quant_offset0,
-            bias0=self.quant_bias_qkv,
-            quant_scale1=self.quant_scale1,
-            quant_offset1=self.quant_offset1,
-            bias1=self.qb_qt_bias,
-            ctkv_scale=self.ctkv_scale,
-            q_nope_scale=self.q_nope_scale,
-            cache_mode="krope_ctkv",
-            quant_mode="per_tensor_quant_asymm",
-            q_out0=decode_q_nope,
-            kv_cache_out0=decode_k_nope,
-            q_out1=decode_q_pe,
-            kv_cache_out1=decode_k_pe,
-            enable_inner_out=False,
-            inner_out=torch.tensor([], device=hidden_states.device))
-        decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
-                                           self.kv_lora_rank)
-        decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
-
+    def reorg_decode_q(self, decode_q_nope, decode_q_pe):
         if self.dcp_size > 1:
             decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe],
                                           dim=-1)
             decode_q_no_split = get_dcp_group().all_gather(
                 decode_q_no_split, 1)
             decode_q_nope, decode_q_pe = decode_q_no_split.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        return decode_q_nope, decode_q_pe
 
-        decode_preprocess_res = DecodeMLAPreprocessResult(
-            decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
-        return decode_preprocess_res, None
-
-    def _forward_prefill_cp(
+    def _forward_prefill(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
@@ -721,6 +400,9 @@ class AscendMlaCPImpl(AscendMLAImpl):
         kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
     ) -> torch.Tensor:
+        if not self.pcp_size > 1:
+            return super()._forward_prefill(q_nope, q_pe, k_nope, k_pe, value,
+                                            kv_c_and_k_pe_cache, attn_metadata)
         assert attn_metadata.prefill is not None
         assert attn_metadata.prefill.pcp_metadata is not None
         num_tokens = q_nope.size(0)
@@ -840,7 +522,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
             softmax_lse=attn_lse)
         return attn_output, attn_lse
 
-    def _forward_decode_pcp_dcp(
+    def _forward_decode(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
@@ -1014,13 +696,9 @@ class AscendMlaCPImpl(AscendMLAImpl):
 
     def _reorg_kvcache(
         self,
-        allgatered_kv_c_normed: torch.Tensor,
-        allgatered_k_pe: torch.Tensor,
-        padded_local_chunk_seq_lens_lst: list[int],
-        local_context_lens_allranks: list[list[int]],
-        sum_seq_len: int,
-        max_seq_len: int,
-        chunk_size: int,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        chunked_context: CPChunkedContextMetadata,
         chunk_idx: int,
         toks: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -1044,6 +722,29 @@ class AscendMlaCPImpl(AscendMLAImpl):
             chunk_idx: chunk idx of chunked_prefill.
             toks: the number of tokens for local gather cache.
         """
+        assert chunked_context is not None
+        assert chunked_context.padded_local_chunk_seq_lens is not None
+        assert chunked_context.local_context_lens_allranks is not None
+        assert chunked_context.cu_seq_lens_lst is not None
+        assert chunked_context.max_seq_lens is not None
+        assert chunked_context.chunk_size is not None
+
+        padded_local_chunk_seq_lens_lst = chunked_context.padded_local_chunk_seq_lens[
+            chunk_idx]
+        local_context_lens_allranks = chunked_context.local_context_lens_allranks
+        sum_seq_len = chunked_context.cu_seq_lens_lst[chunk_idx][-1]
+        max_seq_len = chunked_context.max_seq_lens[chunk_idx]
+        chunk_size: int = chunked_context.chunk_size
+        cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
+        if self.dcp_size > 1:
+            cache_kv_c_k_pe = get_dcp_group().all_gather(cache_kv_c_k_pe, 0)
+
+        if self.pcp_size > 1:
+            cache_kv_c_k_pe = get_pcp_group().all_gather(cache_kv_c_k_pe, 0)
+
+        allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
+            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
         kv_c_segments = []
         k_pe_segments = []
         src_token_idx = 0
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 6454a294..096deb6a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -503,8 +503,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             common_attn_metadata.block_table_tensor[:common_attn_metadata.
                                                     num_reqs])
 
-    def set_decode_block_table(
-            self, common_attn_metadata: AscendCommonAttentionMetadata):
+    def set_decode_block_table(self):
         self.block_table = self.block_table[:self.num_decodes, ...]
 
     def build_prefill_metadata(
@@ -564,7 +563,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             self.seq_lens = self.seq_lens[:self.num_decodes]
             input_positions = input_positions[:self.num_decode_tokens]
 
-            self.set_decode_block_table(common_attn_metadata)
+            self.set_decode_block_table()
 
             # NOTE: Currently, MTP-fullgraph is incompatibility pcp
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
@@ -895,6 +894,26 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
 
+    def get_context_seq_len_npu(self, index: int,
+                                attn_metadata: AscendMLAMetadata):
+        prefill_metadata = attn_metadata.prefill
+        assert prefill_metadata is not None
+        assert prefill_metadata.chunked_context is not None
+        assert prefill_metadata.chunked_context.chunk_seq_lens_npu is not None
+        iters = len(prefill_metadata.chunked_context.seq_tot)
+        assert 0 <= index < iters
+        return prefill_metadata.chunked_context.chunk_seq_lens_npu[index]
+
+    def _reorg_kvcache(
+        self,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        chunked_context: CPChunkedContextMetadata,
+        chunk_idx: int,
+        toks: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return kv_c_normed, k_pe
+
     def _compute_prefill_context(
         self,
         q_nope: torch.Tensor,
@@ -923,9 +942,9 @@ class AscendMLAImpl(MLAAttentionImpl):
             # chunk_seq_lens will be padded when pcp&dcp
             context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
                 i]
-            context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
-                i]
             seq_len = torch.stack([current_seq_len, context_seq_len])
+            context_seq_len_npu = self.get_context_seq_len_npu(
+                i, attn_metadata)
             kv_c_normed = torch.empty(toks,
                                       num_heads,
                                       latent_kv_dim,
@@ -946,7 +965,13 @@ class AscendMLAImpl(MLAAttentionImpl):
                 key=kv_c_normed,
                 value=k_pe,
             )
-
+            kv_c_normed, k_pe = self._reorg_kvcache(
+                kv_c_normed,
+                k_pe,
+                chunked_context=prefill_metadata.chunked_context,
+                chunk_idx=i,
+                toks=toks,
+            )
             kv_c_normed = kv_c_normed.squeeze()
             kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
                 -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
@@ -1210,7 +1235,11 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         return self._v_up_proj(attn_output)
 
-    def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
+    def reorg_decode_q(self, decode_q_nope, decode_q_pe):
+        return decode_q_nope, decode_q_pe
+
+    def _mla_preprocess_only_decode(self, hidden_states, kv_cache,
+                                    attn_metadata):
         bsz = attn_metadata.num_decode_tokens
         hidden_states = hidden_states[:bsz]
 
@@ -1267,10 +1296,57 @@ class AscendMLAImpl(MLAAttentionImpl):
                                            self.kv_lora_rank)
         decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
 
+        decode_q_nope, decode_q_pe = self.reorg_decode_q(
+            decode_q_nope, decode_q_pe)
+
         decode_preprocess_res = DecodeMLAPreprocessResult(
             decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
         return decode_preprocess_res, None
 
+    def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache,
+                               attn_metadata):
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_actual_tokens = attn_metadata.num_actual_tokens
+        prefill_kv_no_split = kv_no_split[num_decode_tokens:num_actual_tokens]
+        prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
+        prefill_q = self.q_proj(prefill_q_c)[0] \
+            .view(-1, self.num_heads, self.qk_head_dim)
+        prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+        prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
+        cos = attn_metadata.prefill.cos
+        sin = attn_metadata.prefill.sin
+        prefill_slots = attn_metadata.slot_mapping[
+            num_decode_tokens:num_actual_tokens]
+        prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+        prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
+            prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
+        prefill_k_nope, prefill_value = self.kv_b_proj(
+            prefill_k_c_normed)[0].view(
+                -1, self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim).split(
+                    [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+        prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
+                                         self.num_kv_heads, -1)
+        prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1))
+        return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe,
+                                          prefill_k_nope, prefill_k_pe,
+                                          prefill_value)
+
+    def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata):
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        decode_q_c = q_c[:num_decode_tokens]
+        cos = attn_metadata.decode.cos
+        sin = attn_metadata.decode.sin
+        decode_ql_nope, decode_q_pe = \
+            self._q_proj_and_k_up_proj(decode_q_c)
+        decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+        decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1]
+        decode_kv_no_split = kv_no_split[:num_decode_tokens]
+        decode_k_pe, decode_k_nope = self.exec_kv_decode(
+            decode_kv_no_split, cos, sin, kv_cache, decode_slots)
+        return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe,
+                                         decode_k_nope, decode_k_pe)
+
     def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
                         attn_metadata, need_gather_q_kv):
         # MLA Preprocess:
@@ -1284,8 +1360,6 @@ class AscendMLAImpl(MLAAttentionImpl):
         #    prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
         has_decode = attn_metadata.num_decodes > 0
         has_prefill = attn_metadata.num_prefills > 0
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        num_actual_tokens = attn_metadata.num_actual_tokens
         if self.fused_qkv_a_proj is not None:
             maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
                                dependency=hidden_states,
@@ -1318,48 +1392,17 @@ class AscendMLAImpl(MLAAttentionImpl):
             wait_for_kv_layer_from_connector(layer_name)
         # Preprocess for decode tokens
         if has_decode:
-            decode_q_c = q_c[:num_decode_tokens]
-            cos = attn_metadata.decode.cos
-            sin = attn_metadata.decode.sin
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_q_c)
-            decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
-            decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1]
-            decode_kv_no_split = kv_no_split[:num_decode_tokens]
-            decode_k_pe, decode_k_nope = self.exec_kv_decode(
-                decode_kv_no_split, cos, sin, kv_cache, decode_slots)
-            decode_preprocess_res = DecodeMLAPreprocessResult(
-                decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe)
+            decode_preprocess_res = self.mla_preprocess_decode(
+                q_c, kv_no_split, kv_cache, attn_metadata)
         # Preprocess for prefill tokens
         if has_prefill:
-            prefill_kv_no_split = kv_no_split[
-                num_decode_tokens:num_actual_tokens]
-            prefill_q_c = q_c[num_decode_tokens:num_actual_tokens]
-            prefill_q = self.q_proj(prefill_q_c)[0] \
-                .view(-1, self.num_heads, self.qk_head_dim)
-            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
-            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
-            cos = attn_metadata.prefill.cos
-            sin = attn_metadata.prefill.sin
-            prefill_slots = attn_metadata.slot_mapping[
-                num_decode_tokens:num_actual_tokens]
-            prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
-            prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
-                prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
-            prefill_k_nope, prefill_value = self.kv_b_proj(
-                prefill_k_c_normed)[0].view(
-                    -1, self.num_heads,
-                    self.qk_nope_head_dim + self.v_head_dim).split(
-                        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
-                                             self.num_kv_heads, -1)
-            prefill_k_pe = prefill_k_pe.expand(
-                (*prefill_k_nope.shape[:-1], -1))
-            prefill_preprocess_res = PrefillMLAPreprocessResult(
-                prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
-                prefill_value)
+            prefill_preprocess_res = self.mla_preprocess_prefill(
+                q_c, kv_no_split, kv_cache, attn_metadata)
         return decode_preprocess_res, prefill_preprocess_res
 
+    def get_num_actual_tokens(self, attn_metadata: M):
+        return attn_metadata.num_actual_tokens
+
     def forward(
         self,
         layer_name,
@@ -1378,7 +1421,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             return output.fill_(0)
 
         forward_context = get_forward_context()
-        num_actual_tokens = attn_metadata.num_actual_tokens
+        num_actual_tokens = self.get_num_actual_tokens(attn_metadata)
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
@@ -1397,13 +1440,12 @@ class AscendMLAImpl(MLAAttentionImpl):
         if self.enable_mlapo and not has_prefill:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
-            decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
+            decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode(
                 hidden_states, kv_cache, attn_metadata)
         else:
            decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
                 layer_name, hidden_states, kv_cache, attn_metadata,
                 need_gather_q_kv)
-
         if decode_preprocess_res is not None:
             # MLA Preprocess for decoding
             output_decode = self._forward_decode(decode_preprocess_res.ql_nope,
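
For readers skimming the diff: the refactor follows a template-method shape. The shared
control flow (forward, _mla_preprocess, _compute_prefill_context) now lives only in
AscendMLAImpl, and AscendMlaCPImpl overrides narrow hooks such as get_num_actual_tokens,
reorg_decode_q, and _reorg_kvcache. The sketch below is a minimal, self-contained
illustration of that pattern, not code from either file; Meta, BaseImpl, and CPImpl are
hypothetical stand-ins for AscendMLAMetadata, AscendMLAImpl, and AscendMlaCPImpl, and only
the get_num_actual_tokens logic is taken from the patch.

    # Minimal sketch of the template-method shape used by this patch.
    # "Meta", "BaseImpl", and "CPImpl" are simplified stand-ins, not the
    # real vllm-ascend classes.
    from dataclasses import dataclass


    @dataclass
    class Meta:
        num_actual_tokens: int
        num_actual_tokens_pcp_padded: int


    class BaseImpl:
        # The shared skeleton is written once here; subclasses override
        # small hooks instead of duplicating the whole method.
        def get_num_actual_tokens(self, meta: Meta) -> int:
            return meta.num_actual_tokens

        def forward(self, meta: Meta) -> int:
            # Common control flow calls the hook rather than branching
            # on context-parallel state inline.
            return self.get_num_actual_tokens(meta)


    class CPImpl(BaseImpl):
        def __init__(self, pcp_size: int):
            self.pcp_size = pcp_size

        def get_num_actual_tokens(self, meta: Meta) -> int:
            # Context-parallel override, mirroring AscendMlaCPImpl.
            if self.pcp_size > 1:
                return meta.num_actual_tokens_pcp_padded // self.pcp_size
            return meta.num_actual_tokens


    meta = Meta(num_actual_tokens=7, num_actual_tokens_pcp_padded=16)
    assert BaseImpl().forward(meta) == 7
    assert CPImpl(pcp_size=2).forward(meta) == 8

With this shape, deleting the duplicated forward/_mla_preprocess bodies from mla_cp.py
(the 561-line reduction in the diffstat) does not change behavior: the CP variant only
supplies the pieces that actually differ.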