diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index dfb9a2a0..6806563e 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -63,10 +63,25 @@ class TestAscendAttentionBackend(TestBase):
 
 class TestAscendAttentionMetadataBuilder(TestBase):
 
-    def setUp(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
+
         self.mock_vllm_config = MagicMock()
         self.mock_vllm_config.model_config.max_model_len = 640
         self.mock_vllm_config.cache_config.block_size = 64
+        self.mock_vllm_config.compilation_config.cudagraph_mode = None
+        self.mock_vllm_config.scheduler_config.max_num_seqs = 10
+        self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
         self.mock_device = 'cpu:0'
         self.builder = AscendAttentionMetadataBuilder(None, None,
                                                       self.mock_vllm_config,
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 258d5e3a..6181631d 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -26,7 +26,7 @@ import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed import (get_dcp_group,
                               get_decode_context_model_parallel_rank,
                               get_decode_context_model_parallel_world_size)
@@ -163,8 +163,8 @@ class AscendMetadataForPrefill:
 @dataclass
 class AscendMetadataForDecode:
     """ Decode Specific Metadata for Ascend"""
-    num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
-        list[int]]]]]] = None
+    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
+    batch_seq_mask: torch.Tensor = None
 
 
 @dataclass
@@ -232,10 +232,25 @@ class AscendAttentionMetadataBuilder:
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
+        self.compilation_config = vllm_config.compilation_config
         self.device = device
         self.max_num_blocks_per_req = cdiv(
            self.model_config.max_model_len,
             AscendAttentionBackend.get_supported_block_size()[0])
+        decode_max_num_seqs = getattr(vllm_config.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        max_num_seqs = max(vllm_config.scheduler_config.max_num_seqs,
+                           decode_max_num_seqs)
+        self.batch_seq_mask_buf = torch.empty(max_num_seqs,
+                                              dtype=torch.uint8,
+                                              device=device)
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.pcp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
 
     def reorder_batch(self, input_batch,
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -356,11 +371,32 @@ class AscendAttentionMetadataBuilder:
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         if common_long_seq_metadata is not None:
             num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
-            num_computed_tokens_of_pcp_dcp = np.array(
+            assert num_computed_tokens_of_pcp_dcp is not None
+            num_computed_tokens_array = np.array(
                 num_computed_tokens_of_pcp_dcp)
+            num_computed_tokens_array = num_computed_tokens_array[:
+                                                                  num_decodes]
+            pad_length = common_attn_metadata.num_input_tokens - num_actual_tokens_pcp_padded // self.pcp_size
+            if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and pad_length > 0:
+                pad_tensor = np.zeros(
+                    (pad_length, num_computed_tokens_array.shape[1],
+                     num_computed_tokens_array.shape[2]),
+                    dtype=num_computed_tokens_array.dtype)
+
+                num_computed_tokens_array = np.concatenate(
+                    [num_computed_tokens_array, pad_tensor], axis=0)
+
+            batch_seq_mask = (
+                num_computed_tokens_array[:, self.pcp_rank,
+                                          self.dcp_rank] == 0)
+            # TODO: numpy array mode of the shared memory is used to improve performance
+            self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                torch.from_numpy(batch_seq_mask), non_blocking=True)
             decode_metadata = AscendMetadataForDecode(
-                num_computed_tokens_of_pcp_dcp=
-                num_computed_tokens_of_pcp_dcp)
+                num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
+                batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
+                                                       shape[0]],
+            )
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -869,7 +905,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
         else:
             num_heads = self.num_heads
 
-        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -880,7 +915,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "BSND",
+            'TND',
             'atten_mask':
             None,
             'scale':
@@ -898,10 +933,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
             'actual_seq_lengths_kv':
             attn_metadata.decode_meta.
             num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+            'actual_seq_lengths':
+            attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes],
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
-        num_tokens = q_nope.shape[0]
+        num_tokens = query.shape[0]
 
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
@@ -913,26 +950,27 @@ class AscendAttentionBackendImpl(AttentionImpl):
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    q_nope, k_nope, value, **common_kwargs)
+                    query, k_nope, value, **common_kwargs)
                 update_graph_params_workspaces(num_tokens,
                                                weak_ref_tensors(workspace))
-            attn_out = torch.empty_like(q_nope)
-            attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
+            attn_out = torch.empty_like(query)
+            attn_lse = torch.empty((num_tokens, num_heads, 1),
                                    dtype=torch.float,
-                                   device=q_nope.device)
+                                   device=query.device)
 
-            graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
-                 weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
-                 self.scale, attn_metadata.block_tables,
-                 self.key_cache.shape[1], attn_metadata.decode_meta.
-                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
-                                                self.dcp_rank],
-                 weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
-                 self.pcp_rank, self.dcp_rank, self.dcp_size))
+            graph_params.attn_params[num_tokens].append((
+                weak_ref_tensors(query), weak_ref_tensors(k_nope),
+                weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+                self.scale, attn_metadata.block_tables,
+                self.key_cache.shape[1], attn_metadata.decode_meta.
+                num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
+                                               self.dcp_rank],
+                attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes],
+                weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
+                self.dcp_size, self.pcp_rank, self.dcp_rank))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
+                query,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -942,11 +980,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                q_nope, k_nope, value, **common_kwargs)
+                query, k_nope, value, **common_kwargs)
 
-        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
-                                 attn_out.shape[3])
-        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
+        out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
+                                                            None].expand_as(
+                                                                attn_out)
+        attn_out = torch.where(out_mask, 0, attn_out)
+
+        lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
+                                                            None].expand_as(
+                                                                attn_lse)
+        attn_lse = torch.where(lse_mask, -torch.inf, attn_lse)
 
         attn_out_lse_list = []
         # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index faf03253..9e7e3c32 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -140,8 +140,7 @@ class AscendMLADecodeMetadata:
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
-    num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
-        list[int]]]]]] = None
+    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
     seq_mask_pcp: torch.Tensor = None
     seq_mask_dcp: torch.Tensor = None
     cp_seq_len: torch.Tensor = None
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index ede83f74..48118c05 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -16,8 +16,7 @@ class AscendPrefillContextParallelMetadata:
 
     num_actual_tokens_pcp_padded: Optional[int] = None
 
-    num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
-        list[int]]]]]] = None
+    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
 
     q_head_idx_tensor: torch.Tensor = None
 
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index b9293267..82410ec8 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -7,7 +7,6 @@ from dataclasses import dataclass
 from typing import Any, Callable, Optional
 from unittest.mock import patch
 
-import numpy as np
 import torch
 import torch_npu
 import vllm.envs as envs
@@ -302,16 +301,25 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
             graph_params.events[runtime_shape],
     ):
         (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
-         block_table, block_size, actual_seq_lengths_kv, attn_output,
-         softmax_lse, pcp_rank, dcp_rank, dcp_size) = param
-        actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank,
-                                                            dcp_rank]
-        pad_length = runtime_shape - len(actual_seq_lengths_kv)
-        pad_tensor = np.zeros(pad_length,
-                              dtype=actual_seq_lengths_kv.dtype)
-        actual_seq_lengths_kv = np.concatenate(
-            [actual_seq_lengths_kv, pad_tensor])
+         block_table, block_size, actual_seq_lengths_kv,
+         actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
+         pcp_rank, dcp_rank) = param
+        attn_metadata = forward_context.attn_metadata[key]
+        actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
+                                                                                         pcp_rank,
+                                                                                         dcp_rank]
+        if (runtime_shape - len(actual_seq_lengths_kv)) > 0:
+            actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (
+                runtime_shape - len(actual_seq_lengths_kv))
+
+        actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
+                                                                  attn_metadata
+                                                                  .
+                                                                  num_decode_tokens]
+        if (runtime_shape - len(actual_seq_lengths_q)):
+            actual_seq_lengths_q = actual_seq_lengths_q + [
+                actual_seq_lengths_q[-1]
+            ] * (runtime_shape - len(actual_seq_lengths_q))
 
         if dcp_size > 1:
             num_heads = num_heads * dcp_size
@@ -323,7 +331,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
             value,
             num_heads=num_heads,
             num_key_value_heads=num_kv_heads,
-            input_layout="BSND",
+            input_layout="TND",
            atten_mask=None,
             scale=scale,
             antiquant_mode=0,
@@ -332,6 +340,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
             block_table=block_table,
             block_size=block_size,
             actual_seq_lengths_kv=actual_seq_lengths_kv,
+            actual_seq_lengths=actual_seq_lengths_q,
             workspace=graph_params.workspaces.get(runtime_shape),
             out=[attn_output, softmax_lse])
         torch.npu.graph_task_update_end(update_stream)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8e4acdd0..66868dd6 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1411,11 +1411,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                                       req_indices, positions_np)
         self.input_batch.block_table.commit_slot_mapping(
             total_num_scheduled_tokens)
-        tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
-            tokens)
-        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
-        # update total_num_scheduled_tokens
-        total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
+        if self.pcp_size > 1:
+            tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
+                tokens)
+            num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+            total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
+        else:
+            position_pcp, pcp_unpad_mask = None, None
+
         self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
         total_num_pcp_pads = sum(self.num_pcp_pads)
         max_num_scheduled_tokens = max(tokens)
@@ -4180,8 +4183,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     def _update_tokens_for_pcp(self, tokens):
         num_reqs = self.input_batch.num_reqs
         self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
-        if not self.pcp_size > 1:
-            return tokens, None, None
         tokens = np.array(tokens, dtype=np.int32)
         num_decode_reqs = sum(
             self.input_batch.num_computed_tokens_cpu[:num_reqs] >=