[feat]decode convert bsnd to tnd and fix bug when pcp and dcp (#3980)

### What this PR does / why we need it?
1、in attention_v1 module, convert bsnd t0 tnd when pcp and dcp
2、fix tochair bug: service startup problem

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2025-11-06 14:58:24 +08:00
committed by GitHub
parent 25b24c02ea
commit 2eebe1dc0a
6 changed files with 118 additions and 51 deletions

View File

@@ -63,10 +63,25 @@ class TestAscendAttentionBackend(TestBase):
class TestAscendAttentionMetadataBuilder(TestBase): class TestAscendAttentionMetadataBuilder(TestBase):
def setUp(self): @patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
self.mock_vllm_config = MagicMock() self.mock_vllm_config = MagicMock()
self.mock_vllm_config.model_config.max_model_len = 640 self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.cache_config.block_size = 64 self.mock_vllm_config.cache_config.block_size = 64
self.mock_vllm_config.compilation_config.cudagraph_mode = None
self.mock_vllm_config.scheduler_config.max_num_seqs = 10
self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
self.mock_device = 'cpu:0' self.mock_device = 'cpu:0'
self.builder = AscendAttentionMetadataBuilder(None, None, self.builder = AscendAttentionMetadataBuilder(None, None,
self.mock_vllm_config, self.mock_vllm_config,

View File

@@ -26,7 +26,7 @@ import torch.nn as nn
import torch_npu import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType) AttentionLayer, AttentionType)
from vllm.config import VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import (get_dcp_group, from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank, get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size) get_decode_context_model_parallel_world_size)
@@ -163,8 +163,8 @@ class AscendMetadataForPrefill:
@dataclass @dataclass
class AscendMetadataForDecode: class AscendMetadataForDecode:
""" Decode Specific Metadata for Ascend""" """ Decode Specific Metadata for Ascend"""
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[ num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
list[int]]]]]] = None batch_seq_mask: torch.Tensor = None
@dataclass @dataclass
@@ -232,10 +232,25 @@ class AscendAttentionMetadataBuilder:
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.compilation_config = vllm_config.compilation_config
self.device = device self.device = device
self.max_num_blocks_per_req = cdiv( self.max_num_blocks_per_req = cdiv(
self.model_config.max_model_len, self.model_config.max_model_len,
AscendAttentionBackend.get_supported_block_size()[0]) AscendAttentionBackend.get_supported_block_size()[0])
decode_max_num_seqs = getattr(vllm_config.scheduler_config,
'decode_max_num_seqs', 0)
max_num_seqs = max(vllm_config.scheduler_config.max_num_seqs,
decode_max_num_seqs)
self.batch_seq_mask_buf = torch.empty(max_num_seqs,
dtype=torch.uint8,
device=device)
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
def reorder_batch(self, input_batch, def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
@@ -356,11 +371,32 @@ class AscendAttentionMetadataBuilder:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None: if common_long_seq_metadata is not None:
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
num_computed_tokens_of_pcp_dcp = np.array( assert num_computed_tokens_of_pcp_dcp is not None
num_computed_tokens_array = np.array(
num_computed_tokens_of_pcp_dcp) num_computed_tokens_of_pcp_dcp)
num_computed_tokens_array = num_computed_tokens_array[:
num_decodes]
pad_length = common_attn_metadata.num_input_tokens - num_actual_tokens_pcp_padded // self.pcp_size
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and pad_length > 0:
pad_tensor = np.zeros(
(pad_length, num_computed_tokens_array.shape[1],
num_computed_tokens_array.shape[2]),
dtype=num_computed_tokens_array.dtype)
num_computed_tokens_array = np.concatenate(
[num_computed_tokens_array, pad_tensor], axis=0)
batch_seq_mask = (
num_computed_tokens_array[:, self.pcp_rank,
self.dcp_rank] == 0)
# TODO: numpy array mode of the shared memory is used to improve performance
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
torch.from_numpy(batch_seq_mask), non_blocking=True)
decode_metadata = AscendMetadataForDecode( decode_metadata = AscendMetadataForDecode(
num_computed_tokens_of_pcp_dcp= num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
num_computed_tokens_of_pcp_dcp) batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
shape[0]],
)
attn_metadata = AscendMetadata( attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
@@ -869,7 +905,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
else: else:
num_heads = self.num_heads num_heads = self.num_heads
q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
k_nope = self.key_cache.view(self.key_cache.shape[0], k_nope = self.key_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1) self.key_cache.shape[1], -1)
value = self.value_cache.view(self.key_cache.shape[0], value = self.value_cache.view(self.key_cache.shape[0],
@@ -880,7 +915,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
'num_key_value_heads': 'num_key_value_heads':
self.num_kv_heads, self.num_kv_heads,
'input_layout': 'input_layout':
"BSND", 'TND',
'atten_mask': 'atten_mask':
None, None,
'scale': 'scale':
@@ -898,10 +933,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
'actual_seq_lengths_kv': 'actual_seq_lengths_kv':
attn_metadata.decode_meta. attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank], num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
'actual_seq_lengths':
attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes],
} }
graph_params = get_graph_params() graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
num_tokens = q_nope.shape[0] num_tokens = query.shape[0]
if forward_context.capturing: if forward_context.capturing:
stream = torch_npu.npu.current_stream() stream = torch_npu.npu.current_stream()
@@ -913,26 +950,27 @@ class AscendAttentionBackendImpl(AttentionImpl):
workspace = graph_params.workspaces.get(num_tokens) workspace = graph_params.workspaces.get(num_tokens)
if workspace is None: if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, value, **common_kwargs) query, k_nope, value, **common_kwargs)
update_graph_params_workspaces(num_tokens, update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace)) weak_ref_tensors(workspace))
attn_out = torch.empty_like(q_nope) attn_out = torch.empty_like(query)
attn_lse = torch.empty((num_tokens, num_heads, 1, 1), attn_lse = torch.empty((num_tokens, num_heads, 1),
dtype=torch.float, dtype=torch.float,
device=q_nope.device) device=query.device)
graph_params.attn_params[num_tokens].append( graph_params.attn_params[num_tokens].append((
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope), weak_ref_tensors(query), weak_ref_tensors(k_nope),
weak_ref_tensors(value), self.num_heads, self.num_kv_heads, weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
self.scale, attn_metadata.block_tables, self.scale, attn_metadata.block_tables,
self.key_cache.shape[1], attn_metadata.decode_meta. self.key_cache.shape[1], attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
self.dcp_rank], self.dcp_rank],
attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes],
weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse), weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
self.pcp_rank, self.dcp_rank, self.dcp_size)) self.dcp_size, self.pcp_rank, self.dcp_rank))
torch.npu.graph_task_group_begin(stream) torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out( torch_npu.npu_fused_infer_attention_score.out(
q_nope, query,
k_nope, k_nope,
value, value,
**common_kwargs, **common_kwargs,
@@ -942,11 +980,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
graph_params.handles[num_tokens].append(handle) graph_params.handles[num_tokens].append(handle)
else: else:
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score( attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
q_nope, k_nope, value, **common_kwargs) query, k_nope, value, **common_kwargs)
attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2], out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
attn_out.shape[3]) None].expand_as(
attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1) attn_out)
attn_out = torch.where(out_mask, 0, attn_out)
lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
None].expand_as(
attn_lse)
attn_lse = torch.where(lse_mask, -torch.inf, attn_lse)
attn_out_lse_list = [] attn_out_lse_list = []
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]

View File

@@ -140,8 +140,7 @@ class AscendMLADecodeMetadata:
attn_mask: Optional[torch.Tensor] = None attn_mask: Optional[torch.Tensor] = None
sin: torch.Tensor = None sin: torch.Tensor = None
cos: torch.Tensor = None cos: torch.Tensor = None
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[ num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
list[int]]]]]] = None
seq_mask_pcp: torch.Tensor = None seq_mask_pcp: torch.Tensor = None
seq_mask_dcp: torch.Tensor = None seq_mask_dcp: torch.Tensor = None
cp_seq_len: torch.Tensor = None cp_seq_len: torch.Tensor = None

View File

@@ -16,8 +16,7 @@ class AscendPrefillContextParallelMetadata:
num_actual_tokens_pcp_padded: Optional[int] = None num_actual_tokens_pcp_padded: Optional[int] = None
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[ num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
list[int]]]]]] = None
q_head_idx_tensor: torch.Tensor = None q_head_idx_tensor: torch.Tensor = None

View File

@@ -7,7 +7,6 @@ from dataclasses import dataclass
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from unittest.mock import patch from unittest.mock import patch
import numpy as np
import torch import torch
import torch_npu import torch_npu
import vllm.envs as envs import vllm.envs as envs
@@ -302,16 +301,25 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
graph_params.events[runtime_shape], graph_params.events[runtime_shape],
): ):
(q_nope, k_nope, value, num_heads, num_kv_heads, scale, (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
block_table, block_size, actual_seq_lengths_kv, attn_output, block_table, block_size, actual_seq_lengths_kv,
softmax_lse, pcp_rank, dcp_rank, dcp_size) = param actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
actual_seq_lengths_kv = forward_context.attn_metadata[ pcp_rank, dcp_rank) = param
key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, attn_metadata = forward_context.attn_metadata[key]
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
pcp_rank,
dcp_rank] dcp_rank]
pad_length = runtime_shape - len(actual_seq_lengths_kv) if (runtime_shape - len(actual_seq_lengths_kv)) > 0:
pad_tensor = np.zeros(pad_length, actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (
dtype=actual_seq_lengths_kv.dtype) runtime_shape - len(actual_seq_lengths_kv))
actual_seq_lengths_kv = np.concatenate(
[actual_seq_lengths_kv, pad_tensor]) actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
attn_metadata
.
num_decode_tokens]
if (runtime_shape - len(actual_seq_lengths_q)):
actual_seq_lengths_q = actual_seq_lengths_q + [
actual_seq_lengths_q[-1]
] * (runtime_shape - len(actual_seq_lengths_q))
if dcp_size > 1: if dcp_size > 1:
num_heads = num_heads * dcp_size num_heads = num_heads * dcp_size
@@ -323,7 +331,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
value, value,
num_heads=num_heads, num_heads=num_heads,
num_key_value_heads=num_kv_heads, num_key_value_heads=num_kv_heads,
input_layout="BSND", input_layout="TND",
atten_mask=None, atten_mask=None,
scale=scale, scale=scale,
antiquant_mode=0, antiquant_mode=0,
@@ -332,6 +340,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
block_table=block_table, block_table=block_table,
block_size=block_size, block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv, actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths_q,
workspace=graph_params.workspaces.get(runtime_shape), workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse]) out=[attn_output, softmax_lse])
torch.npu.graph_task_update_end(update_stream) torch.npu.graph_task_update_end(update_stream)

View File

@@ -1411,11 +1411,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_indices, positions_np) req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping( self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens) total_num_scheduled_tokens)
if self.pcp_size > 1:
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp( tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
tokens) tokens)
num_scheduled_tokens = np.array(tokens, dtype=np.int32) num_scheduled_tokens = np.array(tokens, dtype=np.int32)
# update total_num_scheduled_tokens
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
else:
position_pcp, pcp_unpad_mask = None, None
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
total_num_pcp_pads = sum(self.num_pcp_pads) total_num_pcp_pads = sum(self.num_pcp_pads)
max_num_scheduled_tokens = max(tokens) max_num_scheduled_tokens = max(tokens)
@@ -4180,8 +4183,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _update_tokens_for_pcp(self, tokens): def _update_tokens_for_pcp(self, tokens):
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
self.num_pcp_pads = self.num_pcp_pads[:num_reqs] self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
if not self.pcp_size > 1:
return tokens, None, None
tokens = np.array(tokens, dtype=np.int32) tokens = np.array(tokens, dtype=np.int32)
num_decode_reqs = sum( num_decode_reqs = sum(
self.input_batch.num_computed_tokens_cpu[:num_reqs] >= self.input_batch.num_computed_tokens_cpu[:num_reqs] >=