[feat]dcp pcp support aclgraph (#3731)

### What this PR does / why we need it?
Adds full ACL graph (aclgraph) support for the DCP and PCP paths, covering both the MLA (mla_v1) and attention_v1 backends.
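At a high level the change follows the same capture-then-update pattern used by the existing full-graph attention paths: during ACL graph capture each attention launch is recorded inside a task group, and before every replay the recorded handle is used to re-issue the kernel with the arguments that change per step (per-rank KV lengths, masks, workspaces). A minimal sketch of that pattern is below; the helper names (`capture_attention`, `update_attention`) are hypothetical, only the `graph_task_*` calls come from this diff.

```python
import torch
import torch_npu

# Hypothetical helpers sketching the capture/replay-update pattern this PR
# applies to the DCP/PCP attention kernels; not the vllm-ascend API itself.

def capture_attention(kernel, args, kwargs, graph_params, num_tokens, workspace):
    # Capture path: record the kernel launch inside a task group so the same
    # launch can be patched later without re-capturing the whole graph.
    stream = torch_npu.npu.current_stream()
    torch.npu.graph_task_group_begin(stream)
    kernel(*args, **kwargs, workspace=workspace)
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)
    return handle

def update_attention(kernel, args, kwargs, handle, update_stream, workspace):
    # Replay path: re-issue the kernel against the recorded handle with the
    # arguments that changed this step (e.g. per-rank KV sequence lengths
    # padded to the captured shape).
    with torch.npu.stream(update_stream):
        torch.npu.graph_task_update_begin(update_stream, handle)
        kernel(*args, **kwargs, workspace=workspace)
        torch.npu.graph_task_update_end(update_stream)
```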

- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Author: weiguihua2
Date: 2025-10-27 09:58:23 +08:00 (committed via GitHub)
Parent: 8ab8111fde
Commit: 4312a92a4f
5 changed files with 414 additions and 68 deletions


@@ -176,16 +176,30 @@ class TestAscendMLAMetadata(TestBase):
 class TestAscendMLAMetadataBuilder(TestBase):
 
-    def test_ascend_mla_metadata_builder_default(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
+                                                 mock_dcp, mock_get_dcp_group):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
 
         mock_vllm_config.speculative_config = None
         ascend_config = MagicMock()
@@ -200,16 +214,31 @@ class TestAscendMLAMetadataBuilder(TestBase):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.chunked_prefill_enabled)
 
-    def test_ascend_mla_metadata_builder_spec_decode(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
+                                                     mock_dcp,
+                                                     mock_get_dcp_group):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
 
         mock_spec_config = MagicMock()
         mock_spec_config.num_speculative_tokens = 3
         mock_vllm_config.speculative_config = mock_spec_config
@@ -226,16 +255,30 @@ class TestAscendMLAMetadataBuilder(TestBase):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.chunked_prefill_enabled)
 
-    def test_reorder_batch(self):
+    @patch('vllm.distributed.parallel_state.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+           return_value=1)
+    def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
+                           mock_get_dcp_group):
         ascend_config = MagicMock()
 
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
         mock_device = 'cpu'
+        mock_dcp.world_size = 1
+        dcp_group = MagicMock(spec=GroupCoordinator)
+        dcp_group.rank_in_group = 0
+        dcp_group.world_size = 1
+        dcp_group.device_group = MagicMock()
+        mock_get_dcp_group.return_value = dcp_group
         mock_vllm_config.speculative_config = None
 
         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",


@@ -865,26 +865,81 @@ class AscendAttentionBackendImpl(AttentionImpl):
         num_heads = self.num_heads
 
         # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score(
-            query.view(query.shape[0], 1, query.shape[1], query.shape[2]),
-            # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
-            self.key_cache.view(self.key_cache.shape[0],
-                                self.key_cache.shape[1], -1),
-            self.value_cache.view(self.key_cache.shape[0],
-                                  self.key_cache.shape[1], -1),
-            num_heads=num_heads,
-            num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
-            atten_mask=None,
-            scale=self.scale,
-            antiquant_mode=0,
-            antiquant_scale=None,
-            softmax_lse_flag=True,
-            block_table=attn_metadata.block_tables,
-            block_size=self.key_cache.shape[1],
-            actual_seq_lengths_kv=attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
-        )
+        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
+        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
+        k_nope = self.key_cache.view(self.key_cache.shape[0],
+                                     self.key_cache.shape[1], -1)
+        value = self.value_cache.view(self.key_cache.shape[0],
+                                      self.key_cache.shape[1], -1)
+        common_kwargs = {
+            'num_heads': num_heads,
+            'num_key_value_heads': self.num_kv_heads,
+            'input_layout': "BSND",
+            'atten_mask': None,
+            'scale': self.scale,
+            'antiquant_mode': 0,
+            'antiquant_scale': None,
+            'softmax_lse_flag': True,
+            'block_table': attn_metadata.block_tables,
+            'block_size': self.key_cache.shape[1],
+            'actual_seq_lengths_kv': attn_metadata.decode_meta.
+            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        num_tokens = query.shape[0]
+        if forward_context.capturing:
+            stream = torch_npu.npu.current_stream()
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    q_nope, k_nope, value, **common_kwargs)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+
+            attn_out = torch.empty_like(q_nope)
+            attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
+                                   dtype=torch.float,
+                                   device=q_nope.device)
+
+            graph_params.attn_params[num_tokens].append(
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+                 self.scale, attn_metadata.block_tables,
+                 self.key_cache.shape[1], attn_metadata.decode_meta.
+                 num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
+                                                self.dcp_rank],
+                 weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
+                 self.pcp_rank, self.dcp_rank, self.dcp_size))
+
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                value,
+                **common_kwargs,
+                workspace=workspace,
+                out=[attn_out, attn_lse])
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        else:
+            attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
+                q_nope, k_nope, value, **common_kwargs)
+
         attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
                                  attn_out.shape[3])


@@ -140,6 +140,9 @@ class AscendMLADecodeMetadata:
     cos: torch.Tensor = None
     num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
         list[int]]]]]] = None
+    seq_mask_pcp: torch.Tensor = None
+    seq_mask_dcp: torch.Tensor = None
+    cp_seq_len: torch.Tensor = None
 
 
 @dataclass
@@ -259,6 +262,24 @@ class AscendMLAMetadataBuilder:
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
         self.cos_cache = None
         self.sin_cache = None
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.cp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+        decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
+                                      0)
+        max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
+        self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
+                                            self.pcp_size,
+                                            dtype=torch.uint8,
+                                            device=device)
+        self.seq_mask_dcp_buf = torch.empty(max_num_seqs,
+                                            self.dcp_size,
+                                            dtype=torch.uint8,
+                                            device=device)
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -463,6 +484,41 @@ class AscendMLAMetadataBuilder:
             block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()
 
+            if num_computed_tokens_of_pcp_dcp is not None:
+                num_computed_tokens_of_cp_dcp_array = np.array(
+                    num_computed_tokens_of_pcp_dcp
+                )[:num_decodes]  # [bs, pcp_size, dcp_size]
+                seq_mask_pcp = torch.where(
+                    torch.tensor(
+                        num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0,
+                    1).to(torch.uint8)
+                self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp.
+                                      shape[1]].copy_(seq_mask_pcp,
+                                                      non_blocking=True)
+                seq_mask_pcp_shape = (seq_mask_pcp.shape[0],
+                                      seq_mask_pcp.shape[1])
+                seq_mask_dcp = torch.where(
+                    torch.tensor(
+                        num_computed_tokens_of_cp_dcp_array[:, self.cp_rank, :])
+                    == 0, 0, 1).to(torch.uint8)
+                self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp.
+                                      shape[1]].copy_(seq_mask_dcp,
+                                                      non_blocking=True)
+                seq_mask_dcp_shape = (seq_mask_dcp.shape[0],
+                                      seq_mask_dcp.shape[1])
+                cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
+                                                                 self.cp_rank,
+                                                                 self.dcp_rank]
+                cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
+                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+            else:
+                seq_mask_pcp_shape = (0, 0)
+                seq_mask_dcp_shape = (0, 0)
+                cp_seq_len = None
+
             # TODO: After the fullgraph supports MTP, the if branch needs to deleted
             assert self.cos_cache is not None
             assert self.sin_cache is not None
@@ -485,7 +541,14 @@ class AscendMLAMetadataBuilder:
                     sin=sin,
                     cos=cos,
                     num_computed_tokens_of_pcp_dcp=
-                    num_computed_tokens_of_pcp_dcp)
+                    num_computed_tokens_of_pcp_dcp,
+                    seq_mask_pcp=self.
+                    seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
+                                     seq_mask_pcp_shape[1]],
+                    seq_mask_dcp=self.
+                    seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
+                                     seq_mask_dcp_shape[1]],
+                    cp_seq_len=cp_seq_len)
             else:
                 cos[:num_decode_tokens,
                     ...] = self.cos_cache[input_positions].unsqueeze(
@@ -505,7 +568,14 @@ class AscendMLAMetadataBuilder:
                     sin=sin[:num_decode_tokens, ...],
                     cos=cos[:num_decode_tokens, ...],
                     num_computed_tokens_of_pcp_dcp=
-                    num_computed_tokens_of_pcp_dcp)
+                    num_computed_tokens_of_pcp_dcp,
+                    seq_mask_pcp=self.
+                    seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
+                                     seq_mask_pcp_shape[1]],
+                    seq_mask_dcp=self.
+                    seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
+                                     seq_mask_dcp_shape[1]],
+                    cp_seq_len=cp_seq_len)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
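For orientation, the masks and per-rank lengths the builder now precomputes can be reproduced on a toy `num_computed_tokens_of_pcp_dcp` array. The values below are illustrative only, but the operations mirror the builder logic above: sum over dcp ranks for the pcp mask, slice by the local pcp rank for the dcp mask, and clamp zero lengths to 1 because the decode kernel rejects `seq_len == 0`.

```python
import numpy as np
import torch

# Illustrative values: 2 decode requests, pcp_size=2, dcp_size=2.
# num_computed[b][p][d] = KV tokens request b has stored on (pcp rank p, dcp rank d).
num_computed = np.array([
    [[5, 3], [0, 0]],   # request 0: KV only on pcp rank 0
    [[4, 4], [2, 1]],   # request 1: KV on both pcp ranks
])
pcp_rank, dcp_rank = 0, 1  # pretend we are this rank

# 1 where a pcp rank holds any KV for the request, else 0.
seq_mask_pcp = torch.from_numpy((num_computed.sum(2) != 0).astype(np.uint8))
# Same, but across dcp ranks within our pcp rank.
seq_mask_dcp = torch.from_numpy(
    (num_computed[:, pcp_rank, :] != 0).astype(np.uint8))
# KV length on this (pcp, dcp) rank; zeros clamped to 1 since the kernel
# rejects seq_len == 0 and the mask zeroes the contribution afterwards.
cp_seq_len = torch.tensor(num_computed[:, pcp_rank, dcp_rank],
                          dtype=torch.int32).clamp_min(1)

print(seq_mask_pcp.tolist(), seq_mask_dcp.tolist(), cp_seq_len.tolist())
# [[1, 0], [1, 1]] [[1, 1], [1, 1]] [3, 4]
```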
@@ -1590,36 +1660,63 @@ class AscendMLAImpl(MLAAttentionImpl):
         q_nope = q_nope.view(num_tokens, num_heads, -1)
         q_pe = q_pe.view(num_tokens, num_heads, -1)
         # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
-        num_computed_tokens_of_pcp_dcp = np.array(
-            decode_meta.num_computed_tokens_of_pcp_dcp
-        )[:attn_metadata.num_decodes]  # [bs, pcp_size, dcp_size]
-        seq_mask_pcp = torch.where(
-            torch.tensor(num_computed_tokens_of_pcp_dcp.sum(2)) == 0, 0,
-            1).to(torch.uint8).to(q_pe.device)
-        seq_mask_dcp = torch.where(
-            torch.tensor(
-                num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, :]) == 0, 0,
-            1).to(torch.uint8).to(q_pe.device)
-        seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
-                                                 self.dcp_rank]
-        seq_len = torch.tensor(seq_len, dtype=torch.int32)
-        # npu_multi_head_latent_attention does not support seq_len = 0,
-        # update where seq_len == 0 to 1.
-        # This will not influence result, since we will use seq_mask to update lse.
-        seq_len = torch.where(seq_len == 0, 1, seq_len)
-        if torch.sum(seq_len).item() == 0:
-            # Case that no kv_cache has been stored on this rank, no need to do following computation.
-            attn_output = torch.zeros(
-                [num_tokens, num_heads, self.kv_lora_rank],
-                dtype=q_nope.dtype,
-                device=q_nope.device)
-            softmax_lse = torch.full((num_tokens, num_heads, 1),
-                                     float('-inf'),
-                                     dtype=q_nope.dtype,
-                                     device=q_nope.device)
+        seq_mask_pcp = decode_meta.seq_mask_pcp
+        seq_mask_dcp = decode_meta.seq_mask_dcp
+        seq_len = decode_meta.cp_seq_len
+
+        common_kwargs = {
+            "return_lse": True,
+            "calc_type": "calc_type_ring",
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.capturing:
+            stream = torch_npu.npu.current_stream()
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
+                    q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
+                    seq_len, num_heads, self.scale, self.num_kv_heads,
+                    **common_kwargs)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+
+            attn_output = torch.empty_like(q_nope)
+            softmax_lse = torch.empty((num_tokens, num_heads, 1),
+                                      dtype=q_nope.dtype,
+                                      device=q_nope.device)
+
+            graph_params.attn_params[num_tokens].append(
+                (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
+                 weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
+                 decode_meta.block_table, seq_len, num_heads, self.scale,
+                 self.num_kv_heads, weak_ref_tensors(attn_output),
+                 weak_ref_tensors(softmax_lse)))
+
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.atb.npu_multi_head_latent_attention(
+                q_nope,
+                q_pe,
+                k_nope,
+                k_pe,
+                decode_meta.block_table,
+                seq_len,
+                num_heads,
+                self.scale,
+                self.num_kv_heads,
+                **common_kwargs,
+                workspace=workspace,
+                output=attn_output,
+                lse=softmax_lse)
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
         else:
-            attn_output, softmax_lse = torch_npu.atb.npu_multi_head_latent_attention(
+            attn_output = torch.empty_like(q_nope)
+            softmax_lse = torch.empty((num_tokens, num_heads, 1),
+                                      dtype=q_nope.dtype,
+                                      device=q_nope.device)
+            torch_npu.atb.npu_multi_head_latent_attention(
                 q_nope,
                 q_pe,
                 k_nope,
@@ -1630,7 +1727,9 @@ class AscendMLAImpl(MLAAttentionImpl):
                 self.scale,
                 self.num_kv_heads,
                 return_lse=True,
-                calc_type="calc_type_ring")
+                calc_type="calc_type_ring",
+                output=attn_output,
+                lse=softmax_lse)
 
         if self.dcp_size > 1:
             # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
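The per-rank partial outputs and their LSE values concatenated here are later combined across ranks. As background, the standard log-sum-exp merge used by context-parallel / ring attention looks roughly like the sketch below; this illustrates the math only, not the fused NPU ops vllm-ascend actually calls.

```python
import torch

def merge_partial_attention(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    """Combine per-rank partial attention outputs using their log-sum-exp.

    outs: [ranks, bs, num_heads, v_head_dim] partial (softmax-normalized) outputs
    lses: [ranks, bs, num_heads, 1] log-sum-exp of each rank's attention scores
    """
    # Each rank's weight is its share of the global softmax denominator.
    lse_global = torch.logsumexp(lses, dim=0, keepdim=True)  # [1, bs, heads, 1]
    weights = torch.exp(lses - lse_global)                   # [ranks, bs, heads, 1]
    return (weights * outs).sum(dim=0)                       # [bs, heads, v_head_dim]

# Example with 2 dcp ranks, 1 request, 2 heads, v_head_dim 4.
outs = torch.randn(2, 1, 2, 4)
lses = torch.randn(2, 1, 2, 1)
print(merge_partial_attention(outs, lses).shape)  # torch.Size([1, 2, 4])
```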


@@ -7,6 +7,7 @@ from dataclasses import dataclass
 from typing import Any, Callable, Optional
 from unittest.mock import patch
 
+import numpy as np
 import torch
 import torch_npu
 import vllm.envs as envs
@@ -300,6 +301,105 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
         event.record(update_stream)
 
 
+def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
+         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
+         dcp_rank, dcp_size) = param
+        actual_seq_lengths_kv = forward_context.attn_metadata[
+            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+                                                            dcp_rank]
+        pad_length = runtime_shape - len(actual_seq_lengths_kv)
+        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
+        actual_seq_lengths_kv = np.concatenate(
+            [actual_seq_lengths_kv, pad_tensor])
+        if dcp_size > 1:
+            num_heads = num_heads * dcp_size
+
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                k_nope,
+                value,
+                num_heads=num_heads,
+                num_key_value_heads=num_kv_heads,
+                input_layout="BSND",
+                atten_mask=None,
+                scale=scale,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                softmax_lse_flag=True,
+                block_table=block_table,
+                block_size=block_size,
+                actual_seq_lengths_kv=actual_seq_lengths_kv,
+                workspace=graph_params.workspaces.get(runtime_shape),
+                out=[attn_output, softmax_lse])
+            torch.npu.graph_task_update_end(update_stream)
+
+        event.record(update_stream)
+
+
+def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
+                                   runtime_shape, speculative_config):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
+         num_kv_heads, attn_output, softmax_lse) = param
+        decode_meta = forward_context.attn_metadata[key].decode
+        seq_len = decode_meta.cp_seq_len
+
+        if speculative_config and speculative_config.method == "deepseek_mtp":
+            spec_multiple = speculative_config.num_speculative_tokens + 1
+            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
+                                       len(seq_len))
+        else:
+            pad_length = runtime_shape - len(seq_len)
+            pad_tensor = torch.zeros(pad_length,
+                                     dtype=seq_len.dtype,
+                                     device=seq_len.device)
+            seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+            torch_npu.atb.npu_multi_head_latent_attention(
+                q_nope,
+                q_pe,
+                k_nope,
+                k_pe,
+                block_table,
+                seq_len,
+                num_heads,
+                scale,
+                num_kv_heads,
+                return_lse=True,
+                calc_type="calc_type_ring",
+                workspace=graph_params.workspaces.get(runtime_shape),
+                output=attn_output,
+                lse=softmax_lse)
+            torch.npu.graph_task_update_end(update_stream)
+
+        event.record(update_stream)
+
+
 @dataclass
 class GraphParams:
     events: dict[int, list[torch.npu.ExternalEvent]]
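One detail worth calling out in the update helpers above: a captured graph has a fixed token count, so the per-request KV lengths taken from the live metadata must be padded back up to the shape the graph was captured with before the kernel arguments are patched. A hypothetical standalone helper showing the idea (the diff inlines this logic):

```python
import torch

def pad_to_runtime_shape(seq_len: torch.Tensor, runtime_shape: int) -> torch.Tensor:
    # Pad per-request KV lengths with zeros up to the captured batch shape so
    # the updated kernel arguments match the recorded launch exactly.
    pad_length = runtime_shape - seq_len.numel()
    if pad_length <= 0:
        return seq_len[:runtime_shape]
    pad = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
    return torch.cat([seq_len, pad], dim=0)

# Example: graph captured for 8 decode tokens, only 3 live requests this step.
print(pad_to_runtime_shape(torch.tensor([5, 1, 7], dtype=torch.int32), 8))
# tensor([5, 1, 7, 0, 0, 0, 0, 0], dtype=torch.int32)
```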


@@ -110,10 +110,14 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          AscendPrefillContextParallelMetadata)
+# yapf: disable
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
+                                               update_attn_dcp_pcp_params,
                                                update_attn_params,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
+# yapf: enable
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
@@ -1649,6 +1653,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             slot_mapping = blk_table.slot_mapping[:slot_mapping_size]
             blk_table.slot_mapping[slot_mapping_size:].fill_(0)
             if self.pcp_size > 1:
+                slot_mapping_for_pcp = blk_table.slot_mapping[:
+                    long_seq_metadata.num_actual_tokens_pcp_padded]
+                slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
                 assert pcp_unpad_mask is not None
                 pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:
                                                                        pcp_unpad_mask
@@ -1657,10 +1666,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                                                            0]]
                 pcp_padded_slot_mapping.fill_(-1)
                 pcp_padded_slot_mapping[
-                    pcp_unpad_mask] = blk_table.slot_mapping[:
-                                                             slot_mapping_size]
-                blk_table.slot_mapping[:long_seq_metadata.
-                                       num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
+                    pcp_unpad_mask] = slot_mapping_for_pcp[:slot_mapping_size]
+                slot_mapping_for_pcp[:long_seq_metadata.
+                                     num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
+                slot_mapping = slot_mapping_for_pcp
 
         # Make AscendCommonAttentionMetadata
         common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1749,10 +1759,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
             # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
             if self.vllm_config.model_config.use_mla:
-                # FIXME: Try using `auto_dispatch_capture=True`
-                update_mla_attn_params(self.update_stream, forward_context,
-                                       maybe_padded_num_tokens,
-                                       self.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_dcp_pcp_params(self.update_stream,
+                                                   forward_context,
+                                                   maybe_padded_num_tokens,
+                                                   self.speculative_config)
+                else:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_params(self.update_stream, forward_context,
+                                           maybe_padded_num_tokens,
+                                           self.speculative_config)
             else:
-                update_attn_params(self.update_stream, forward_context,
-                                   maybe_padded_num_tokens)
+                if self.pcp_size * self.dcp_size > 1:
+                    update_attn_dcp_pcp_params(self.update_stream,
+                                               forward_context,
+                                               maybe_padded_num_tokens)
+                else:
+                    update_attn_params(self.update_stream, forward_context,
+                                       maybe_padded_num_tokens)
@@ -2488,6 +2510,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 kv_cache_group_id].get_device_tensor()
             slot_mapping = self.input_batch.block_table[
                 kv_cache_group_id].slot_mapping
+            self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
+                                                 dtype=torch.int32,
+                                                 device=self.device)
+            long_seq_metadata = self._generate_pcp_metadata(
+                num_tokens, self.seq_lens_cpu)
+            if long_seq_metadata is not None:
+                pcp_world_size = get_pcp_group(
+                ).world_size if prefill_context_parallel_enable() else 1
+                dcp_world_size = get_dcp_group().world_size
+                num_computed_tokens_of_pcp_dcp = [[
+                    [0] * dcp_world_size for _ in range(pcp_world_size)
+                ] for _ in range(num_tokens)]
+                long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
             common_attn_metadata = AscendCommonAttentionMetadata(
                 query_start_loc=torch.tensor(
                     [0] + self.actual_seq_lengths_q[:num_reqs],
@@ -2511,6 +2546,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 decode_token_per_req=self.decode_token_per_req,
                 cos=self.cos,
                 sin=self.sin,
+                prefill_context_parallel_metadata=long_seq_metadata,
             )
 
             attn_state = AscendAttentionState.DecodeOnly
             if self.speculative_config and \
@@ -2539,10 +2575,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                 not forward_context.capturing:
                 if self.vllm_config.model_config.use_mla:
-                    # FIXME: Try using `auto_dispatch_capture=True`
-                    update_mla_attn_params(self.update_stream, forward_context,
-                                           positions.shape[0],
-                                           self.speculative_config)
+                    if self.pcp_size * self.dcp_size > 1:
+                        # FIXME: Try using `auto_dispatch_capture=True`
+                        update_mla_attn_dcp_pcp_params(self.update_stream,
+                                                       forward_context,
+                                                       positions.shape[0],
+                                                       self.speculative_config)
+                    else:
+                        # FIXME: Try using `auto_dispatch_capture=True`
+                        update_mla_attn_params(self.update_stream,
+                                               forward_context,
+                                               positions.shape[0],
+                                               self.speculative_config)
                 else:
-                    update_attn_params(self.update_stream, forward_context,
-                                       positions.shape[0])
+                    if self.pcp_size * self.dcp_size > 1:
+                        update_attn_dcp_pcp_params(self.update_stream,
+                                                   forward_context,
+                                                   positions.shape[0])
+                    else:
+                        update_attn_params(self.update_stream, forward_context,
+                                           positions.shape[0])