From 4312a92a4f480f2a4ba79baf184be1b55e6da139 Mon Sep 17 00:00:00 2001
From: weiguihua2
Date: Mon, 27 Oct 2025 09:58:23 +0800
Subject: [PATCH] [feat]dcp pcp support aclgraph (#3731)

### What this PR does / why we need it?
Add full ACL graph (aclgraph) support for DCP (decode context parallel) and PCP
(prefill context parallel), covering both the MLA backend (mla_v1) and the
standard attention backend (attention_v1).

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/c9461e05a4ed3557cfbf4b15ded1e26761cc39ca

Signed-off-by: weiguihua2
---
 tests/ut/attention/test_mla_v1.py     |  49 +++++++-
 vllm_ascend/attention/attention_v1.py |  93 ++++++++++++---
 vllm_ascend/attention/mla_v1.py       | 161 +++++++++++++++++++++-----
 vllm_ascend/compilation/acl_graph.py  | 100 ++++++++++++++++
 vllm_ascend/worker/model_runner_v1.py |  79 ++++++++++---
 5 files changed, 414 insertions(+), 68 deletions(-)

diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index e7b17a36..d8ddc6a6 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -176,16 +176,30 @@ class TestAscendMLAMetadata(TestBase):
 
 class TestAscendMLAMetadataBuilder(TestBase):
 
-    def test_ascend_mla_metadata_builder_default(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
+                                                 mock_dcp, mock_get_dcp_group):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
         mock_vllm_config.speculative_config = None
 
         ascend_config = MagicMock()
@@ -200,16 +214,31 @@ class TestAscendMLAMetadataBuilder(TestBase):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.chunked_prefill_enabled)
 
-    def test_ascend_mla_metadata_builder_spec_decode(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
+                                                     mock_dcp,
+                                                     mock_get_dcp_group):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
 
+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
         mock_spec_config = MagicMock()
         mock_spec_config.num_speculative_tokens = 3
mock_vllm_config.speculative_config = mock_spec_config @@ -226,16 +255,30 @@ class TestAscendMLAMetadataBuilder(TestBase): builder.chunked_prefill_enabled, mock_vllm_config.scheduler_config.chunked_prefill_enabled) - def test_reorder_batch(self): + @patch('vllm.distributed.parallel_state.get_dcp_group') + @patch('vllm.distributed.parallel_state._DCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) + def test_reorder_batch(self, mock_get_dcp_size, mock_dcp, + mock_get_dcp_group): ascend_config = MagicMock() mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.cache_config.block_size = 16 mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_dcp.world_size = 1 + dcp_group = MagicMock(spec=GroupCoordinator) + dcp_group.rank_in_group = 0 + dcp_group.world_size = 1 + dcp_group.device_group = MagicMock() + mock_get_dcp_group.return_value = dcp_group + mock_vllm_config.speculative_config = None with patch("vllm_ascend.attention.mla_v1.get_ascend_config", diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index ef36c04c..62bca309 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -865,26 +865,81 @@ class AscendAttentionBackendImpl(AttentionImpl): num_heads = self.num_heads # 1. Compute out&lse by "npu_fused_infer_attention_score" - attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score( - query.view(query.shape[0], 1, query.shape[1], query.shape[2]), - # [b,num_heads,head_size] -> [b,1,num_heads,head_size] - self.key_cache.view(self.key_cache.shape[0], - self.key_cache.shape[1], -1), - self.value_cache.view(self.key_cache.shape[0], - self.key_cache.shape[1], -1), - num_heads=num_heads, - num_key_value_heads=self.num_kv_heads, - input_layout="BSND", - atten_mask=None, - scale=self.scale, - antiquant_mode=0, - antiquant_scale=None, - softmax_lse_flag=True, - block_table=attn_metadata.block_tables, - block_size=self.key_cache.shape[1], - actual_seq_lengths_kv=attn_metadata.decode_meta. + q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2]) + # [b,num_heads,head_size] -> [b,1,num_heads,head_size] + k_nope = self.key_cache.view(self.key_cache.shape[0], + self.key_cache.shape[1], -1) + value = self.value_cache.view(self.key_cache.shape[0], + self.key_cache.shape[1], -1) + common_kwargs = { + 'num_heads': + num_heads, + 'num_key_value_heads': + self.num_kv_heads, + 'input_layout': + "BSND", + 'atten_mask': + None, + 'scale': + self.scale, + 'antiquant_mode': + 0, + 'antiquant_scale': + None, + 'softmax_lse_flag': + True, + 'block_table': + attn_metadata.block_tables, + 'block_size': + self.key_cache.shape[1], + "actual_seq_lengths_kv": + attn_metadata.decode_meta. 
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank], - ) + } + graph_params = get_graph_params() + forward_context: ForwardContext = get_forward_context() + num_tokens = query.shape[0] + if forward_context.capturing: + stream = torch_npu.npu.current_stream() + + event = torch.npu.ExternalEvent() + event.wait(stream) + event.reset(stream) + graph_params.events[num_tokens].append(event) + + workspace = graph_params.workspaces.get(num_tokens) + if workspace is None: + workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + q_nope, k_nope, value, **common_kwargs) + update_graph_params_workspaces(num_tokens, + weak_ref_tensors(workspace)) + attn_out = torch.empty_like(q_nope) + attn_lse = torch.empty((num_tokens, num_heads, 1, 1), + dtype=torch.float, + device=q_nope.device) + + graph_params.attn_params[num_tokens].append( + (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope), + weak_ref_tensors(value), self.num_heads, self.num_kv_heads, + self.scale, attn_metadata.block_tables, + self.key_cache.shape[1], attn_metadata.decode_meta. + num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, + self.dcp_rank], + weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse), + self.pcp_rank, self.dcp_rank, self.dcp_size)) + torch.npu.graph_task_group_begin(stream) + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + value, + **common_kwargs, + workspace=workspace, + out=[attn_out, attn_lse]) + handle = torch.npu.graph_task_group_end(stream) + graph_params.handles[num_tokens].append(handle) + else: + attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score( + q_nope, k_nope, value, **common_kwargs) attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2], attn_out.shape[3]) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index ebba38c1..d57e7316 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -140,6 +140,9 @@ class AscendMLADecodeMetadata: cos: torch.Tensor = None num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[ list[int]]]]]] = None + seq_mask_pcp: torch.Tensor = None + seq_mask_dcp: torch.Tensor = None + cp_seq_len: torch.Tensor = None @dataclass @@ -259,6 +262,24 @@ class AscendMLAMetadataBuilder: self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.cos_cache = None self.sin_cache = None + self.pcp_size = get_prefill_context_model_parallel_world_size( + ) if prefill_context_parallel_enable() else 1 + self.cp_rank = get_prefill_context_model_parallel_rank( + ) if self.pcp_size > 1 else 0 + self.dcp_size = get_decode_context_model_parallel_world_size() + self.dcp_rank = get_decode_context_model_parallel_rank( + ) if self.dcp_size > 1 else 0 + decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', + 0) + max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) + self.seq_mask_pcp_buf = torch.empty(max_num_seqs, + self.pcp_size, + dtype=torch.uint8, + device=device) + self.seq_mask_dcp_buf = torch.empty(max_num_seqs, + self.dcp_size, + dtype=torch.uint8, + device=device) def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -463,6 +484,41 @@ class AscendMLAMetadataBuilder: block_table = block_table[:num_decodes, ...] 
seq_lens_list = seq_lens.tolist() + if num_computed_tokens_of_pcp_dcp is not None: + num_computed_tokens_of_cp_dcp_array = np.array( + num_computed_tokens_of_pcp_dcp + )[:num_decodes] # [bs, pcp_size, dcp_size] + seq_mask_pcp = torch.where( + torch.tensor( + num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0, + 1).to(torch.uint8) + self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp. + shape[1]].copy_(seq_mask_pcp, + non_blocking=True) + seq_mask_pcp_shape = (seq_mask_pcp.shape[0], + seq_mask_pcp.shape[1]) + + seq_mask_dcp = torch.where( + torch.tensor( + num_computed_tokens_of_cp_dcp_array[:, + self.cp_rank, :]) + == 0, 0, 1).to(torch.uint8) + self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp. + shape[1]].copy_(seq_mask_dcp, + non_blocking=True) + seq_mask_dcp_shape = (seq_mask_dcp.shape[0], + seq_mask_dcp.shape[1]) + + cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, + self.cp_rank, + self.dcp_rank] + cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32) + cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) + else: + seq_mask_pcp_shape = (0, 0) + seq_mask_dcp_shape = (0, 0) + cp_seq_len = None + # TODO: After the fullgraph supports MTP, the if branch needs to deleted assert self.cos_cache is not None assert self.sin_cache is not None @@ -485,7 +541,14 @@ class AscendMLAMetadataBuilder: sin=sin, cos=cos, num_computed_tokens_of_pcp_dcp= - num_computed_tokens_of_pcp_dcp) + num_computed_tokens_of_pcp_dcp, + seq_mask_pcp=self. + seq_mask_pcp_buf[:seq_mask_pcp_shape[0], : + seq_mask_pcp_shape[1]], + seq_mask_dcp=self. + seq_mask_dcp_buf[:seq_mask_dcp_shape[0], : + seq_mask_dcp_shape[1]], + cp_seq_len=cp_seq_len) else: cos[:num_decode_tokens, ...] = self.cos_cache[input_positions].unsqueeze( @@ -505,7 +568,14 @@ class AscendMLAMetadataBuilder: sin=sin[:num_decode_tokens, ...], cos=cos[:num_decode_tokens, ...], num_computed_tokens_of_pcp_dcp= - num_computed_tokens_of_pcp_dcp) + num_computed_tokens_of_pcp_dcp, + seq_mask_pcp=self. + seq_mask_pcp_buf[:seq_mask_pcp_shape[0], : + seq_mask_pcp_shape[1]], + seq_mask_dcp=self. + seq_mask_dcp_buf[:seq_mask_dcp_shape[0], : + seq_mask_dcp_shape[1]], + cp_seq_len=cp_seq_len) return self.metadata_cls( # type: ignore num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, @@ -1590,36 +1660,63 @@ class AscendMLAImpl(MLAAttentionImpl): q_nope = q_nope.view(num_tokens, num_heads, -1) q_pe = q_pe.view(num_tokens, num_heads, -1) # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask - num_computed_tokens_of_pcp_dcp = np.array( - decode_meta.num_computed_tokens_of_pcp_dcp - )[:attn_metadata.num_decodes] # [bs, pcp_size, dcp_size] - seq_mask_pcp = torch.where( - torch.tensor(num_computed_tokens_of_pcp_dcp.sum(2)) == 0, 0, - 1).to(torch.uint8).to(q_pe.device) - seq_mask_dcp = torch.where( - torch.tensor( - num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, :]) == 0, 0, - 1).to(torch.uint8).to(q_pe.device) - seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, - self.dcp_rank] - seq_len = torch.tensor(seq_len, dtype=torch.int32) - # npu_multi_head_latent_attention does not support seq_len = 0, - # update where seq_len == 0 to 1. - # This will not influence result, since we will use seq_mask to update lse. 
- seq_len = torch.where(seq_len == 0, 1, seq_len) + seq_mask_pcp = decode_meta.seq_mask_pcp + seq_mask_dcp = decode_meta.seq_mask_dcp + seq_len = decode_meta.cp_seq_len - if torch.sum(seq_len).item() == 0: - # Case that no kv_cache has been stored on this rank, no need to do following computation. - attn_output = torch.zeros( - [num_tokens, num_heads, self.kv_lora_rank], - dtype=q_nope.dtype, - device=q_nope.device) - softmax_lse = torch.full((num_tokens, num_heads, 1), - float('-inf'), - dtype=q_nope.dtype, - device=q_nope.device) + common_kwargs = { + "return_lse": True, + "calc_type": "calc_type_ring", + } + graph_params = get_graph_params() + forward_context: ForwardContext = get_forward_context() + if forward_context.capturing: + stream = torch_npu.npu.current_stream() + event = torch.npu.ExternalEvent() + event.wait(stream) + event.reset(stream) + graph_params.events[num_tokens].append(event) + workspace = graph_params.workspaces.get(num_tokens) + if workspace is None: + workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace( + q_nope, q_pe, k_nope, k_pe, decode_meta.block_table, + seq_len, num_heads, self.scale, self.num_kv_heads, + **common_kwargs) + update_graph_params_workspaces(num_tokens, + weak_ref_tensors(workspace)) + attn_output = torch.empty_like(q_nope) + softmax_lse = torch.empty((num_tokens, num_heads, 1), + dtype=q_nope.dtype, + device=q_nope.device) + graph_params.attn_params[num_tokens].append( + (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe), + weak_ref_tensors(k_nope), weak_ref_tensors(k_pe), + decode_meta.block_table, seq_len, num_heads, self.scale, + self.num_kv_heads, weak_ref_tensors(attn_output), + weak_ref_tensors(softmax_lse))) + torch.npu.graph_task_group_begin(stream) + torch_npu.atb.npu_multi_head_latent_attention( + q_nope, + q_pe, + k_nope, + k_pe, + decode_meta.block_table, + seq_len, + num_heads, + self.scale, + self.num_kv_heads, + **common_kwargs, + workspace=workspace, + output=attn_output, + lse=softmax_lse) + handle = torch.npu.graph_task_group_end(stream) + graph_params.handles[num_tokens].append(handle) else: - attn_output, softmax_lse = torch_npu.atb.npu_multi_head_latent_attention( + attn_output = torch.empty_like(q_nope) + softmax_lse = torch.empty((num_tokens, num_heads, 1), + dtype=q_nope.dtype, + device=q_nope.device) + torch_npu.atb.npu_multi_head_latent_attention( q_nope, q_pe, k_nope, @@ -1630,7 +1727,9 @@ class AscendMLAImpl(MLAAttentionImpl): self.scale, self.num_kv_heads, return_lse=True, - calc_type="calc_type_ring") + calc_type="calc_type_ring", + output=attn_output, + lse=softmax_lse) if self.dcp_size > 1: # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 91d75b52..5548787f 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Any, Callable, Optional from unittest.mock import patch +import numpy as np import torch import torch_npu import vllm.envs as envs @@ -300,6 +301,105 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, event.record(update_stream) +def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape): + graph_params = get_graph_params() + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. 
+ for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table, + block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank, + dcp_rank, dcp_size) = param + actual_seq_lengths_kv = forward_context.attn_metadata[ + key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank, + dcp_rank] + pad_length = runtime_shape - len(actual_seq_lengths_kv) + pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype) + actual_seq_lengths_kv = np.concatenate( + [actual_seq_lengths_kv, pad_tensor]) + if dcp_size > 1: + num_heads = num_heads * dcp_size + + with torch.npu.stream(update_stream): + torch.npu.graph_task_update_begin(update_stream, handle) + + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + value, + num_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout="BSND", + atten_mask=None, + scale=scale, + antiquant_mode=0, + antiquant_scale=None, + softmax_lse_flag=True, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=actual_seq_lengths_kv, + workspace=graph_params.workspaces.get(runtime_shape), + out=[attn_output, softmax_lse]) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + +def update_mla_attn_dcp_pcp_params(update_stream, forward_context, + runtime_shape, speculative_config): + graph_params = get_graph_params() + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale, + num_kv_heads, attn_output, softmax_lse) = param + + decode_meta = forward_context.attn_metadata[key].decode + seq_len = decode_meta.cp_seq_len + + if speculative_config and speculative_config.method == "deepseek_mtp": + spec_multiple = speculative_config.num_speculative_tokens + 1 + seq_len = seq_len + [0] * (runtime_shape // spec_multiple - + len(seq_len)) + else: + pad_length = runtime_shape - len(seq_len) + pad_tensor = torch.zeros(pad_length, + dtype=seq_len.dtype, + device=seq_len.device) + seq_len = torch.cat([seq_len, pad_tensor], dim=0) + + with torch.npu.stream(update_stream): + torch.npu.graph_task_update_begin(update_stream, handle) + + torch_npu.atb.npu_multi_head_latent_attention( + q_nope, + q_pe, + k_nope, + k_pe, + block_table, + seq_len, + num_heads, + scale, + num_kv_heads, + return_lse=True, + calc_type="calc_type_ring", + workspace=graph_params.workspaces.get(runtime_shape), + output=attn_output, + lse=softmax_lse) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + @dataclass class GraphParams: events: dict[int, list[torch.npu.ExternalEvent]] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 2f006dbc..f30a9a39 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -110,10 +110,14 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, AscendPrefillContextParallelMetadata) +# yapf: disable 
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, set_graph_params, + update_attn_dcp_pcp_params, update_attn_params, + update_mla_attn_dcp_pcp_params, update_mla_attn_params) +# yapf: enable from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.core.eplb_device_transfer_loader import \ D2DExpertWeightLoader @@ -1649,6 +1653,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): slot_mapping = blk_table.slot_mapping[:slot_mapping_size] blk_table.slot_mapping[slot_mapping_size:].fill_(0) if self.pcp_size > 1: + slot_mapping_for_pcp = blk_table.slot_mapping[: + long_seq_metadata + . + num_actual_tokens_pcp_padded] + slot_mapping_for_pcp[slot_mapping_size:].fill_(-1) assert pcp_unpad_mask is not None pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: pcp_unpad_mask @@ -1657,10 +1666,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): 0]] pcp_padded_slot_mapping.fill_(-1) pcp_padded_slot_mapping[ - pcp_unpad_mask] = blk_table.slot_mapping[: - slot_mapping_size] - blk_table.slot_mapping[:long_seq_metadata. - num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping + pcp_unpad_mask] = slot_mapping_for_pcp[: + slot_mapping_size] + slot_mapping_for_pcp[:long_seq_metadata. + num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping + slot_mapping = slot_mapping_for_pcp # Make AscendCommonAttentionMetadata common_attn_metadata = AscendCommonAttentionMetadata( @@ -1749,13 +1759,25 @@ class NPUModelRunner(LoRAModelRunnerMixin): if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead if self.vllm_config.model_config.use_mla: - # FIXME: Try using `auto_dispatch_capture=True` - update_mla_attn_params(self.update_stream, forward_context, - maybe_padded_num_tokens, - self.speculative_config) + if self.pcp_size * self.dcp_size > 1: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_dcp_pcp_params(self.update_stream, + forward_context, + maybe_padded_num_tokens, + self.speculative_config) + else: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + maybe_padded_num_tokens, + self.speculative_config) else: - update_attn_params(self.update_stream, forward_context, - maybe_padded_num_tokens) + if self.pcp_size * self.dcp_size > 1: + update_attn_dcp_pcp_params(self.update_stream, + forward_context, + maybe_padded_num_tokens) + else: + update_attn_params(self.update_stream, forward_context, + maybe_padded_num_tokens) if get_forward_context().sp_enabled: hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) @@ -2488,6 +2510,19 @@ class NPUModelRunner(LoRAModelRunnerMixin): kv_cache_group_id].get_device_tensor() slot_mapping = self.input_batch.block_table[ kv_cache_group_id].slot_mapping + self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + long_seq_metadata = self._generate_pcp_metadata( + num_tokens, self.seq_lens_cpu) + if long_seq_metadata is not None: + pcp_world_size = get_pcp_group( + ).world_size if prefill_context_parallel_enable() else 1 + dcp_world_size = get_dcp_group().world_size + num_computed_tokens_of_pcp_dcp = [[ + [0] * dcp_world_size for _ in range(pcp_world_size) + ] for _ in range(num_tokens)] + long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=torch.tensor( [0] + self.actual_seq_lengths_q[:num_reqs], @@ -2511,6 +2546,7 @@ 
class NPUModelRunner(LoRAModelRunnerMixin): decode_token_per_req=self.decode_token_per_req, cos=self.cos, sin=self.sin, + prefill_context_parallel_metadata=long_seq_metadata, ) attn_state = AscendAttentionState.DecodeOnly if self.speculative_config and \ @@ -2540,12 +2576,25 @@ class NPUModelRunner(LoRAModelRunnerMixin): not forward_context.capturing: if self.vllm_config.model_config.use_mla: # FIXME: Try using `auto_dispatch_capture=True` - update_mla_attn_params(self.update_stream, forward_context, - positions.shape[0], - self.speculative_config) + if self.pcp_size * self.dcp_size > 1: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_dcp_pcp_params(self.update_stream, + forward_context, + positions.shape[0], + self.speculative_config) + else: + # FIXME: Try using `auto_dispatch_capture=True` + update_mla_attn_params(self.update_stream, forward_context, + positions.shape[0], + self.speculative_config) else: - update_attn_params(self.update_stream, forward_context, - positions.shape[0]) + if self.pcp_size * self.dcp_size > 1: + update_attn_dcp_pcp_params(self.update_stream, + forward_context, + positions.shape[0]) + else: + update_attn_params(self.update_stream, forward_context, + positions.shape[0]) if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, _ = hidden_states
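
### How the capture path works (illustrative sketch)

Both attention backends in this patch follow the same capture-time pattern: query a
workspace once per runtime shape, launch the fused attention kernel inside a graph
task group, and remember the launch arguments plus an `ExternalEvent` so the captured
kernel can be re-parameterized at replay time. The sketch below is a minimal
illustration of that pattern for the non-MLA path. The `torch_npu` calls are the ones
used in the diff above; `SimpleGraphParams` and `capture_fused_attention` are
simplified stand-ins, not the actual vllm-ascend API.

```python
# Sketch only: mirrors the capture branch added to AscendAttentionBackendImpl.
from collections import defaultdict

import torch
import torch_npu


class SimpleGraphParams:
    """Minimal stand-in for the GraphParams bookkeeping in acl_graph.py."""

    def __init__(self):
        self.events = defaultdict(list)       # runtime_shape -> [ExternalEvent]
        self.workspaces = {}                  # runtime_shape -> workspace tensor
        self.attn_params = defaultdict(list)  # runtime_shape -> per-layer args
        self.handles = defaultdict(list)      # runtime_shape -> task-group handles


def capture_fused_attention(graph_params, q_nope, k_nope, value, num_tokens,
                            **attn_kwargs):
    """Launch npu_fused_infer_attention_score so it can be updated at replay."""
    stream = torch_npu.npu.current_stream()

    # The event is waited on (and reset) inside the captured graph; recording it
    # later from the update stream releases the kernel with fresh arguments.
    event = torch.npu.ExternalEvent()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)

    # The workspace is queried once per runtime shape and reused across layers.
    workspace = graph_params.workspaces.get(num_tokens)
    if workspace is None:
        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
            q_nope, k_nope, value, **attn_kwargs)
        graph_params.workspaces[num_tokens] = workspace

    attn_out = torch.empty_like(q_nope)
    attn_lse = torch.empty((num_tokens, attn_kwargs["num_heads"], 1, 1),
                           dtype=torch.float, device=q_nope.device)

    # Remember the launch arguments so the replay-time updater can re-issue the
    # op with refreshed sequence lengths. The real code stores weak references
    # (weak_ref_tensors) so captured tensors are not kept alive unnecessarily.
    graph_params.attn_params[num_tokens].append(
        (q_nope, k_nope, value, attn_kwargs, attn_out, attn_lse))

    torch.npu.graph_task_group_begin(stream)
    torch_npu.npu_fused_infer_attention_score.out(q_nope, k_nope, value,
                                                  **attn_kwargs,
                                                  workspace=workspace,
                                                  out=[attn_out, attn_lse])
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)
    return attn_out, attn_lse
```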
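On the replay side, `update_attn_dcp_pcp_params` and `update_mla_attn_dcp_pcp_params`
walk the per-layer parameters saved at capture time, take the KV lengths that the
metadata builder precomputed for this rank's (pcp_rank, dcp_rank) slice, pad them with
zeros up to the captured runtime shape, and re-issue the kernel between
`graph_task_update_begin`/`graph_task_update_end` on a dedicated update stream before
recording the external event. The sketch below condenses the MLA variant; the helper
and function names are illustrative, while the `torch_npu` calls and metadata fields
(`decode.cp_seq_len`) come from the diff.

```python
# Sketch only: condensed from update_mla_attn_dcp_pcp_params in acl_graph.py.
import torch
import torch_npu


def pad_seq_len_to_shape(seq_len: torch.Tensor,
                         runtime_shape: int) -> torch.Tensor:
    """Pad per-request KV lengths with zeros up to the captured batch size."""
    pad_length = runtime_shape - seq_len.numel()
    if pad_length <= 0:
        return seq_len
    pad = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
    return torch.cat([seq_len, pad], dim=0)


def replay_update_mla(update_stream, forward_context, runtime_shape,
                      graph_params):
    """Re-issue each captured MLA kernel with the current batch's seq lens."""
    for key, param, handle, event in zip(
            forward_context.attn_metadata,
            graph_params.attn_params[runtime_shape],
            graph_params.handles[runtime_shape],
            graph_params.events[runtime_shape]):
        (q_nope, q_pe, k_nope, k_pe, block_table, _old_seq_len, num_heads,
         scale, num_kv_heads, attn_output, softmax_lse) = param

        # cp_seq_len only covers the live requests, so it is padded out to the
        # shape the graph was captured with. (The deepseek_mtp speculative path
        # in the diff additionally spaces the lengths per speculative step;
        # omitted here.)
        decode_meta = forward_context.attn_metadata[key].decode
        seq_len = pad_seq_len_to_shape(decode_meta.cp_seq_len, runtime_shape)

        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.atb.npu_multi_head_latent_attention(
                q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
                scale, num_kv_heads, return_lse=True,
                calc_type="calc_type_ring",
                workspace=graph_params.workspaces.get(runtime_shape),
                output=attn_output, lse=softmax_lse)
            torch.npu.graph_task_update_end(update_stream)

        event.record(update_stream)
```

At dispatch time the model runner selects these variants only when
`self.pcp_size * self.dcp_size > 1`; otherwise the pre-existing
`update_attn_params` / `update_mla_attn_params` paths are used unchanged.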