[feat]dcp pcp support aclgraph (#3731)

### What this PR does / why we need it?
DCP and PCP now support full ACL graph mode, including MLA attention_v1.
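For reference, a minimal offline-inference sketch of how this combination could be enabled. The knob names follow upstream vLLM conventions (`decode_context_parallel_size`, `compilation_config={"cudagraph_mode": "FULL"}`) and are assumptions, not taken from this PR; the prefill-context-parallel size is configured through the Ascend plugin and is omitted here.

```python
# Minimal sketch, not taken from this PR: knob names follow upstream vLLM
# conventions and the model choice is only for illustration.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",           # any MLA model, for illustration
    tensor_parallel_size=4,
    decode_context_parallel_size=2,                 # DCP (assumed upstream flag)
    compilation_config={"cudagraph_mode": "FULL"},  # full graph -> full aclgraph on Ascend
)
print(llm.generate("Hello")[0].outputs[0].text)
```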

- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Author: weiguihua2
Committed: 2025-10-27 09:58:23 +08:00 (via GitHub)
Parent: 8ab8111fde
Commit: 4312a92a4f
5 changed files with 414 additions and 68 deletions


@@ -176,16 +176,30 @@ class TestAscendMLAMetadata(TestBase):
class TestAscendMLAMetadataBuilder(TestBase):
def test_ascend_mla_metadata_builder_default(self):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
mock_dcp, mock_get_dcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_vllm_config.speculative_config = None
ascend_config = MagicMock()
@@ -200,16 +214,31 @@ class TestAscendMLAMetadataBuilder(TestBase):
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
def test_ascend_mla_metadata_builder_spec_decode(self):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
mock_dcp,
mock_get_dcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_spec_config = MagicMock()
mock_spec_config.num_speculative_tokens = 3
mock_vllm_config.speculative_config = mock_spec_config
@@ -226,16 +255,30 @@ class TestAscendMLAMetadataBuilder(TestBase):
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
def test_reorder_batch(self):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
mock_get_dcp_group):
ascend_config = MagicMock()
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_vllm_config.speculative_config = None
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",


@@ -865,26 +865,81 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_heads = self.num_heads
# 1. Compute out&lse by "npu_fused_infer_attention_score"
attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score(
query.view(query.shape[0], 1, query.shape[1], query.shape[2]),
# [b,num_heads,head_size] -> [b,1,num_heads,head_size]
self.key_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1),
self.value_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1),
num_heads=num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
atten_mask=None,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=attn_metadata.block_tables,
block_size=self.key_cache.shape[1],
actual_seq_lengths_kv=attn_metadata.decode_meta.
q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
# [b,num_heads,head_size] -> [b,1,num_heads,head_size]
k_nope = self.key_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1)
value = self.value_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1)
common_kwargs = {
'num_heads':
num_heads,
'num_key_value_heads':
self.num_kv_heads,
'input_layout':
"BSND",
'atten_mask':
None,
'scale':
self.scale,
'antiquant_mode':
0,
'antiquant_scale':
None,
'softmax_lse_flag':
True,
'block_table':
attn_metadata.block_tables,
'block_size':
self.key_cache.shape[1],
"actual_seq_lengths_kv":
attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
)
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
if forward_context.capturing:
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, value, **common_kwargs)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
attn_out = torch.empty_like(q_nope)
attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
dtype=torch.float,
device=q_nope.device)
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
self.scale, attn_metadata.block_tables,
self.key_cache.shape[1], attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
self.dcp_rank],
weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
self.pcp_rank, self.dcp_rank, self.dcp_size))
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
value,
**common_kwargs,
workspace=workspace,
out=[attn_out, attn_lse])
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
q_nope, k_nope, value, **common_kwargs)
attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
attn_out.shape[3])
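The capture branch above stashes weak references to the kernel arguments, an ExternalEvent, and a graph task-group handle so that, at replay time, the recorded call can be re-issued with the current per-rank KV lengths instead of re-capturing the whole graph. Below is a condensed sketch of that pattern, mirroring the `torch.npu` graph-task and ExternalEvent calls used in this diff (Ascend-only; simplified helper names are mine):

```python
# Condensed sketch of the capture-then-update pattern used above; it mirrors
# the torch.npu graph-task / ExternalEvent calls from this diff and therefore
# requires an Ascend NPU with torch_npu installed.
import torch
import torch_npu


def capture_fia(stream, workspace, q, k, v, out, lse, **kwargs):
    """Capture the fused attention call as an updatable task group."""
    event = torch.npu.ExternalEvent()
    event.wait(stream)   # the replayed graph blocks here until an update is recorded
    event.reset(stream)
    torch.npu.graph_task_group_begin(stream)
    torch_npu.npu_fused_infer_attention_score.out(
        q, k, v, workspace=workspace, out=[out, lse], **kwargs)
    handle = torch.npu.graph_task_group_end(stream)
    return event, handle


def update_fia(update_stream, event, handle, workspace, q, k, v, out, lse, **kwargs):
    """Before each replay, re-issue the captured call with fresh KV lengths
    (passed via kwargs, e.g. actual_seq_lengths_kv) and release the event."""
    with torch.npu.stream(update_stream):
        torch.npu.graph_task_update_begin(update_stream, handle)
        torch_npu.npu_fused_infer_attention_score.out(
            q, k, v, workspace=workspace, out=[out, lse], **kwargs)
        torch.npu.graph_task_update_end(update_stream)
    event.record(update_stream)
```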


@@ -140,6 +140,9 @@ class AscendMLADecodeMetadata:
cos: torch.Tensor = None
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
list[int]]]]]] = None
seq_mask_pcp: torch.Tensor = None
seq_mask_dcp: torch.Tensor = None
cp_seq_len: torch.Tensor = None
@dataclass
@@ -259,6 +262,24 @@ class AscendMLAMetadataBuilder:
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.cp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
self.pcp_size,
dtype=torch.uint8,
device=device)
self.seq_mask_dcp_buf = torch.empty(max_num_seqs,
self.dcp_size,
dtype=torch.uint8,
device=device)
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
@@ -463,6 +484,41 @@ class AscendMLAMetadataBuilder:
block_table = block_table[:num_decodes, ...]
seq_lens_list = seq_lens.tolist()
if num_computed_tokens_of_pcp_dcp is not None:
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp
)[:num_decodes] # [bs, pcp_size, dcp_size]
seq_mask_pcp = torch.where(
torch.tensor(
num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0,
1).to(torch.uint8)
self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp.
shape[1]].copy_(seq_mask_pcp,
non_blocking=True)
seq_mask_pcp_shape = (seq_mask_pcp.shape[0],
seq_mask_pcp.shape[1])
seq_mask_dcp = torch.where(
torch.tensor(
num_computed_tokens_of_cp_dcp_array[:,
self.cp_rank, :])
== 0, 0, 1).to(torch.uint8)
self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp.
shape[1]].copy_(seq_mask_dcp,
non_blocking=True)
seq_mask_dcp_shape = (seq_mask_dcp.shape[0],
seq_mask_dcp.shape[1])
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
self.cp_rank,
self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
else:
seq_mask_pcp_shape = (0, 0)
seq_mask_dcp_shape = (0, 0)
cp_seq_len = None
# TODO: After the fullgraph supports MTP, this if branch needs to be deleted
assert self.cos_cache is not None
assert self.sin_cache is not None
@@ -485,7 +541,14 @@ class AscendMLAMetadataBuilder:
sin=sin,
cos=cos,
num_computed_tokens_of_pcp_dcp=
num_computed_tokens_of_pcp_dcp)
num_computed_tokens_of_pcp_dcp,
seq_mask_pcp=self.
seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
seq_mask_pcp_shape[1]],
seq_mask_dcp=self.
seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
seq_mask_dcp_shape[1]],
cp_seq_len=cp_seq_len)
else:
cos[:num_decode_tokens,
...] = self.cos_cache[input_positions].unsqueeze(
@@ -505,7 +568,14 @@ class AscendMLAMetadataBuilder:
sin=sin[:num_decode_tokens, ...],
cos=cos[:num_decode_tokens, ...],
num_computed_tokens_of_pcp_dcp=
num_computed_tokens_of_pcp_dcp)
num_computed_tokens_of_pcp_dcp,
seq_mask_pcp=self.
seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
seq_mask_pcp_shape[1]],
seq_mask_dcp=self.
seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
seq_mask_dcp_shape[1]],
cp_seq_len=cp_seq_len)
return self.metadata_cls( # type: ignore
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
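The builder above reduces `num_computed_tokens_of_pcp_dcp` — a `[bs, pcp_size, dcp_size]` table of computed KV token counts — to two uint8 masks and the local per-request sequence lengths. A small CPU-runnable sketch of that reduction, using made-up numbers for a 2-request, pcp=2, dcp=2 layout:

```python
# CPU-runnable sketch of the seq_mask / cp_seq_len reduction performed by the
# metadata builder above; the sample numbers are made up for illustration.
import numpy as np
import torch

# [bs, pcp_size, dcp_size] computed-token counts: 2 requests, pcp=2, dcp=2
num_computed = np.array([[[8, 0], [0, 0]],
                         [[4, 4], [4, 0]]])
pcp_rank, dcp_rank = 0, 1

# 1 where a PCP rank holds any KV for the request (summed over its DCP ranks)
seq_mask_pcp = torch.where(
    torch.tensor(num_computed.sum(2)) == 0, 0, 1).to(torch.uint8)
# 1 where a DCP rank within the local PCP rank holds any KV
seq_mask_dcp = torch.where(
    torch.tensor(num_computed[:, pcp_rank, :]) == 0, 0, 1).to(torch.uint8)

# local KV length for this (pcp_rank, dcp_rank); zeros become 1 because the
# attention kernel rejects seq_len == 0, and the masks fix up the lse later
cp_seq_len = torch.tensor(num_computed[:, pcp_rank, dcp_rank], dtype=torch.int32)
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)

print(seq_mask_pcp)  # tensor([[1, 0], [1, 1]], dtype=torch.uint8)
print(seq_mask_dcp)  # tensor([[1, 0], [1, 1]], dtype=torch.uint8)
print(cp_seq_len)    # tensor([1, 4], dtype=torch.int32)
```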
@@ -1590,36 +1660,63 @@ class AscendMLAImpl(MLAAttentionImpl):
q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1)
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
num_computed_tokens_of_pcp_dcp = np.array(
decode_meta.num_computed_tokens_of_pcp_dcp
)[:attn_metadata.num_decodes] # [bs, pcp_size, dcp_size]
seq_mask_pcp = torch.where(
torch.tensor(num_computed_tokens_of_pcp_dcp.sum(2)) == 0, 0,
1).to(torch.uint8).to(q_pe.device)
seq_mask_dcp = torch.where(
torch.tensor(
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, :]) == 0, 0,
1).to(torch.uint8).to(q_pe.device)
seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
self.dcp_rank]
seq_len = torch.tensor(seq_len, dtype=torch.int32)
# npu_multi_head_latent_attention does not support seq_len = 0,
# update where seq_len == 0 to 1.
# This will not influence result, since we will use seq_mask to update lse.
seq_len = torch.where(seq_len == 0, 1, seq_len)
seq_mask_pcp = decode_meta.seq_mask_pcp
seq_mask_dcp = decode_meta.seq_mask_dcp
seq_len = decode_meta.cp_seq_len
if torch.sum(seq_len).item() == 0:
# Case that no kv_cache has been stored on this rank, no need to do following computation.
attn_output = torch.zeros(
[num_tokens, num_heads, self.kv_lora_rank],
dtype=q_nope.dtype,
device=q_nope.device)
softmax_lse = torch.full((num_tokens, num_heads, 1),
float('-inf'),
dtype=q_nope.dtype,
device=q_nope.device)
common_kwargs = {
"return_lse": True,
"calc_type": "calc_type_ring",
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
if forward_context.capturing:
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
seq_len, num_heads, self.scale, self.num_kv_heads,
**common_kwargs)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype,
device=q_nope.device)
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
decode_meta.block_table, seq_len, num_heads, self.scale,
self.num_kv_heads, weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse)))
torch.npu.graph_task_group_begin(stream)
torch_npu.atb.npu_multi_head_latent_attention(
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
**common_kwargs,
workspace=workspace,
output=attn_output,
lse=softmax_lse)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
attn_output, softmax_lse = torch_npu.atb.npu_multi_head_latent_attention(
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype,
device=q_nope.device)
torch_npu.atb.npu_multi_head_latent_attention(
q_nope,
q_pe,
k_nope,
@@ -1630,7 +1727,9 @@ class AscendMLAImpl(MLAAttentionImpl):
self.scale,
self.num_kv_heads,
return_lse=True,
calc_type="calc_type_ring")
calc_type="calc_type_ring",
output=attn_output,
lse=softmax_lse)
if self.dcp_size > 1:
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
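Once each rank has its partial output and log-sum-exp, the partials still have to be reconciled across the DCP group. As a generic illustration only (not the exact fused op this PR dispatches to), the standard numerically stable lse-weighted merge looks like this; ranks that held no KV contribute `-inf` lse and therefore zero weight, which is consistent with the `-inf` fill used above for ranks without cached tokens.

```python
# Generic illustration of merging per-rank attention partials through their
# log-sum-exp values; this is the textbook reduction, not necessarily the
# exact fused kernel this PR uses.
import torch


def merge_attn_partials(outs: list[torch.Tensor], lses: list[torch.Tensor]):
    """outs[i]: [bs, num_heads, v_head_dim], lses[i]: [bs, num_heads, 1] from rank i."""
    out_all = torch.stack(outs)                     # [ranks, bs, heads, v]
    lse_all = torch.stack(lses)                     # [ranks, bs, heads, 1]
    lse_max = lse_all.max(dim=0, keepdim=True).values
    weights = torch.exp(lse_all - lse_max)          # ranks with -inf lse get weight 0
    denom = weights.sum(dim=0)
    out = (weights * out_all).sum(dim=0) / denom
    lse = lse_max.squeeze(0) + torch.log(denom)
    return out, lse
```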


@@ -7,6 +7,7 @@ from dataclasses import dataclass
from typing import Any, Callable, Optional
from unittest.mock import patch
import numpy as np
import torch
import torch_npu
import vllm.envs as envs
@@ -300,6 +301,105 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
event.record(update_stream)
def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
dcp_rank, dcp_size) = param
actual_seq_lengths_kv = forward_context.attn_metadata[
key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
dcp_rank]
pad_length = runtime_shape - len(actual_seq_lengths_kv)
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate(
[actual_seq_lengths_kv, pad_tensor])
if dcp_size > 1:
num_heads = num_heads * dcp_size
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
value,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout="BSND",
atten_mask=None,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse])
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
runtime_shape, speculative_config):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
num_kv_heads, attn_output, softmax_lse) = param
decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len
if speculative_config and speculative_config.method == "deepseek_mtp":
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
len(seq_len))
else:
pad_length = runtime_shape - len(seq_len)
pad_tensor = torch.zeros(pad_length,
dtype=seq_len.dtype,
device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.atb.npu_multi_head_latent_attention(
q_nope,
q_pe,
k_nope,
k_pe,
block_table,
seq_len,
num_heads,
scale,
num_kv_heads,
return_lse=True,
calc_type="calc_type_ring",
workspace=graph_params.workspaces.get(runtime_shape),
output=attn_output,
lse=softmax_lse)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@dataclass
class GraphParams:
events: dict[int, list[torch.npu.ExternalEvent]]


@@ -110,10 +110,14 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendPrefillContextParallelMetadata)
# yapf: disable
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
set_graph_params,
update_attn_dcp_pcp_params,
update_attn_params,
update_mla_attn_dcp_pcp_params,
update_mla_attn_params)
# yapf: enable
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
D2DExpertWeightLoader
@@ -1649,6 +1653,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
slot_mapping = blk_table.slot_mapping[:slot_mapping_size]
blk_table.slot_mapping[slot_mapping_size:].fill_(0)
if self.pcp_size > 1:
slot_mapping_for_pcp = blk_table.slot_mapping[:
long_seq_metadata
.
num_actual_tokens_pcp_padded]
slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
assert pcp_unpad_mask is not None
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:
pcp_unpad_mask
@@ -1657,10 +1666,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
0]]
pcp_padded_slot_mapping.fill_(-1)
pcp_padded_slot_mapping[
pcp_unpad_mask] = blk_table.slot_mapping[:
slot_mapping_size]
blk_table.slot_mapping[:long_seq_metadata.
num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
pcp_unpad_mask] = slot_mapping_for_pcp[:
slot_mapping_size]
slot_mapping_for_pcp[:long_seq_metadata.
num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
slot_mapping = slot_mapping_for_pcp
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1749,13 +1759,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
# TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
if self.vllm_config.model_config.use_mla:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context,
maybe_padded_num_tokens,
self.speculative_config)
if self.pcp_size * self.dcp_size > 1:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context,
maybe_padded_num_tokens,
self.speculative_config)
else:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context,
maybe_padded_num_tokens,
self.speculative_config)
else:
update_attn_params(self.update_stream, forward_context,
maybe_padded_num_tokens)
if self.pcp_size * self.dcp_size > 1:
update_attn_dcp_pcp_params(self.update_stream,
forward_context,
maybe_padded_num_tokens)
else:
update_attn_params(self.update_stream, forward_context,
maybe_padded_num_tokens)
if get_forward_context().sp_enabled:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -2488,6 +2510,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_cache_group_id].get_device_tensor()
slot_mapping = self.input_batch.block_table[
kv_cache_group_id].slot_mapping
self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
long_seq_metadata = self._generate_pcp_metadata(
num_tokens, self.seq_lens_cpu)
if long_seq_metadata is not None:
pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
dcp_world_size = get_dcp_group().world_size
num_computed_tokens_of_pcp_dcp = [[
[0] * dcp_world_size for _ in range(pcp_world_size)
] for _ in range(num_tokens)]
long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor(
[0] + self.actual_seq_lengths_q[:num_reqs],
@@ -2511,6 +2546,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
decode_token_per_req=self.decode_token_per_req,
cos=self.cos,
sin=self.sin,
prefill_context_parallel_metadata=long_seq_metadata,
)
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and \
@@ -2540,12 +2576,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
not forward_context.capturing:
if self.vllm_config.model_config.use_mla:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context,
positions.shape[0],
self.speculative_config)
if self.pcp_size * self.dcp_size > 1:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context,
positions.shape[0],
self.speculative_config)
else:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context,
positions.shape[0],
self.speculative_config)
else:
update_attn_params(self.update_stream, forward_context,
positions.shape[0])
if self.pcp_size * self.dcp_size > 1:
update_attn_dcp_pcp_params(self.update_stream,
forward_context,
positions.shape[0])
else:
update_attn_params(self.update_stream, forward_context,
positions.shape[0])
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
hidden_states, _ = hidden_states