[feat]decode convert bsnd to tnd and fix bug when pcp and dcp (#3980)
### What this PR does / why we need it?
1. In the attention_v1 module, convert the decode attention input layout from BSND to TND when PCP and DCP are used.
2. Fix a torchair bug: service startup problem.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -63,10 +63,25 @@ class TestAscendAttentionBackend(TestBase):
|
||||
|
||||
class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
self.mock_vllm_config = MagicMock()
|
||||
self.mock_vllm_config.model_config.max_model_len = 640
|
||||
self.mock_vllm_config.cache_config.block_size = 64
|
||||
self.mock_vllm_config.compilation_config.cudagraph_mode = None
|
||||
self.mock_vllm_config.scheduler_config.max_num_seqs = 10
|
||||
self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
|
||||
self.mock_device = 'cpu:0'
|
||||
self.builder = AscendAttentionMetadataBuilder(None, None,
|
||||
self.mock_vllm_config,
|
||||
|
||||
@@ -26,7 +26,7 @@ import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import (get_dcp_group,
|
||||
get_decode_context_model_parallel_rank,
|
||||
get_decode_context_model_parallel_world_size)
|
||||
@@ -163,8 +163,8 @@ class AscendMetadataForPrefill:
|
||||
@dataclass
|
||||
class AscendMetadataForDecode:
|
||||
""" Decode Specific Metadata for Ascend"""
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
|
||||
list[int]]]]]] = None
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
batch_seq_mask: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -232,10 +232,25 @@ class AscendAttentionMetadataBuilder:
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.device = device
|
||||
self.max_num_blocks_per_req = cdiv(
|
||||
self.model_config.max_model_len,
|
||||
AscendAttentionBackend.get_supported_block_size()[0])
|
||||
decode_max_num_seqs = getattr(vllm_config.scheduler_config,
|
||||
'decode_max_num_seqs', 0)
|
||||
max_num_seqs = max(vllm_config.scheduler_config.max_num_seqs,
|
||||
decode_max_num_seqs)
|
||||
self.batch_seq_mask_buf = torch.empty(max_num_seqs,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||
) if self.pcp_size > 1 else 0
|
||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||
) if self.dcp_size > 1 else 0
|
||||
|
||||
def reorder_batch(self, input_batch,
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
@@ -356,11 +371,32 @@ class AscendAttentionMetadataBuilder:
|
||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
if common_long_seq_metadata is not None:
|
||||
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||
num_computed_tokens_of_pcp_dcp = np.array(
|
||||
assert num_computed_tokens_of_pcp_dcp is not None
|
||||
num_computed_tokens_array = np.array(
|
||||
num_computed_tokens_of_pcp_dcp)
|
||||
num_computed_tokens_array = num_computed_tokens_array[:
|
||||
num_decodes]
|
||||
pad_length = common_attn_metadata.num_input_tokens - num_actual_tokens_pcp_padded // self.pcp_size
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and pad_length > 0:
|
||||
pad_tensor = np.zeros(
|
||||
(pad_length, num_computed_tokens_array.shape[1],
|
||||
num_computed_tokens_array.shape[2]),
|
||||
dtype=num_computed_tokens_array.dtype)
|
||||
|
||||
num_computed_tokens_array = np.concatenate(
|
||||
[num_computed_tokens_array, pad_tensor], axis=0)
|
||||
|
||||
batch_seq_mask = (
|
||||
num_computed_tokens_array[:, self.pcp_rank,
|
||||
self.dcp_rank] == 0)
|
||||
# TODO: numpy array mode of the shared memory is used to improve performance
|
||||
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
|
||||
torch.from_numpy(batch_seq_mask), non_blocking=True)
|
||||
decode_metadata = AscendMetadataForDecode(
|
||||
num_computed_tokens_of_pcp_dcp=
|
||||
num_computed_tokens_of_pcp_dcp)
|
||||
num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
|
||||
batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
|
||||
shape[0]],
|
||||
)
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -869,7 +905,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
else:
|
||||
num_heads = self.num_heads
|
||||
|
||||
q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
|
||||
k_nope = self.key_cache.view(self.key_cache.shape[0],
|
||||
self.key_cache.shape[1], -1)
|
||||
value = self.value_cache.view(self.key_cache.shape[0],
|
||||
@@ -880,7 +915,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
'num_key_value_heads':
|
||||
self.num_kv_heads,
|
||||
'input_layout':
|
||||
"BSND",
|
||||
'TND',
|
||||
'atten_mask':
|
||||
None,
|
||||
'scale':
|
||||
@@ -898,10 +933,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
'actual_seq_lengths_kv':
|
||||
attn_metadata.decode_meta.
|
||||
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
|
||||
'actual_seq_lengths':
|
||||
attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes],
|
||||
}
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
num_tokens = q_nope.shape[0]
|
||||
num_tokens = query.shape[0]
|
||||
if forward_context.capturing:
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
@@ -913,26 +950,27 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
q_nope, k_nope, value, **common_kwargs)
|
||||
query, k_nope, value, **common_kwargs)
|
||||
update_graph_params_workspaces(num_tokens,
|
||||
weak_ref_tensors(workspace))
|
||||
attn_out = torch.empty_like(q_nope)
|
||||
attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
|
||||
attn_out = torch.empty_like(query)
|
||||
attn_lse = torch.empty((num_tokens, num_heads, 1),
|
||||
dtype=torch.float,
|
||||
device=q_nope.device)
|
||||
device=query.device)
|
||||
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
|
||||
weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
|
||||
self.scale, attn_metadata.block_tables,
|
||||
self.key_cache.shape[1], attn_metadata.decode_meta.
|
||||
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
|
||||
self.dcp_rank],
|
||||
weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
|
||||
self.pcp_rank, self.dcp_rank, self.dcp_size))
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
weak_ref_tensors(query), weak_ref_tensors(k_nope),
|
||||
weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
|
||||
self.scale, attn_metadata.block_tables,
|
||||
self.key_cache.shape[1], attn_metadata.decode_meta.
|
||||
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
|
||||
self.dcp_rank],
|
||||
attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes],
|
||||
weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
|
||||
self.dcp_size, self.pcp_rank, self.dcp_rank))
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
q_nope,
|
||||
query,
|
||||
k_nope,
|
||||
value,
|
||||
**common_kwargs,
|
||||
@@ -942,11 +980,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
|
||||
q_nope, k_nope, value, **common_kwargs)
|
||||
query, k_nope, value, **common_kwargs)
|
||||
|
||||
attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
|
||||
attn_out.shape[3])
|
||||
attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
|
||||
out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
|
||||
None].expand_as(
|
||||
attn_out)
|
||||
attn_out = torch.where(out_mask, 0, attn_out)
|
||||
|
||||
lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
|
||||
None].expand_as(
|
||||
attn_lse)
|
||||
attn_lse = torch.where(lse_mask, -torch.inf, attn_lse)
|
||||
|
||||
attn_out_lse_list = []
|
||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
||||
|
||||
@@ -140,8 +140,7 @@ class AscendMLADecodeMetadata:
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
sin: torch.Tensor = None
|
||||
cos: torch.Tensor = None
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
|
||||
list[int]]]]]] = None
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
seq_mask_pcp: torch.Tensor = None
|
||||
seq_mask_dcp: torch.Tensor = None
|
||||
cp_seq_len: torch.Tensor = None
|
||||
|
||||
@@ -16,8 +16,7 @@ class AscendPrefillContextParallelMetadata:
|
||||
|
||||
num_actual_tokens_pcp_padded: Optional[int] = None
|
||||
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
|
||||
list[int]]]]]] = None
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
|
||||
q_head_idx_tensor: torch.Tensor = None
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
@@ -302,16 +301,25 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
|
||||
block_table, block_size, actual_seq_lengths_kv, attn_output,
|
||||
softmax_lse, pcp_rank, dcp_rank, dcp_size) = param
|
||||
actual_seq_lengths_kv = forward_context.attn_metadata[
|
||||
key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank,
|
||||
dcp_rank]
|
||||
pad_length = runtime_shape - len(actual_seq_lengths_kv)
|
||||
pad_tensor = np.zeros(pad_length,
|
||||
dtype=actual_seq_lengths_kv.dtype)
|
||||
actual_seq_lengths_kv = np.concatenate(
|
||||
[actual_seq_lengths_kv, pad_tensor])
|
||||
block_table, block_size, actual_seq_lengths_kv,
|
||||
actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
|
||||
pcp_rank, dcp_rank) = param
|
||||
attn_metadata = forward_context.attn_metadata[key]
|
||||
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
|
||||
pcp_rank,
|
||||
dcp_rank]
|
||||
if (runtime_shape - len(actual_seq_lengths_kv)) > 0:
|
||||
actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (
|
||||
runtime_shape - len(actual_seq_lengths_kv))
|
||||
|
||||
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[:
|
||||
attn_metadata
|
||||
.
|
||||
num_decode_tokens]
|
||||
if (runtime_shape - len(actual_seq_lengths_q)):
|
||||
actual_seq_lengths_q = actual_seq_lengths_q + [
|
||||
actual_seq_lengths_q[-1]
|
||||
] * (runtime_shape - len(actual_seq_lengths_q))
|
||||
if dcp_size > 1:
|
||||
num_heads = num_heads * dcp_size
|
||||
|
||||
@@ -323,7 +331,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
value,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
input_layout="BSND",
|
||||
input_layout="TND",
|
||||
atten_mask=None,
|
||||
scale=scale,
|
||||
antiquant_mode=0,
|
||||
@@ -332,6 +340,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
|
||||
block_table=block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
actual_seq_lengths=actual_seq_lengths_q,
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
out=[attn_output, softmax_lse])
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
@@ -1411,11 +1411,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
req_indices, positions_np)
|
||||
self.input_batch.block_table.commit_slot_mapping(
|
||||
total_num_scheduled_tokens)
|
||||
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
|
||||
tokens)
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
# update total_num_scheduled_tokens
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
|
||||
if self.pcp_size > 1:
|
||||
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
|
||||
tokens)
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
|
||||
else:
|
||||
position_pcp, pcp_unpad_mask = None, None
|
||||
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
|
||||
|
||||
total_num_pcp_pads = sum(self.num_pcp_pads)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
@@ -4180,8 +4183,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
def _update_tokens_for_pcp(self, tokens):
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
|
||||
if not self.pcp_size > 1:
|
||||
return tokens, None, None
|
||||
tokens = np.array(tokens, dtype=np.int32)
|
||||
num_decode_reqs = sum(
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
|
||||
|
||||
Reference in New Issue
Block a user