Reapply "[Refactor] Unify full-graph parameter update logic (#6041)" (#6227) (#6231)

This reverts commit 95649344aa.

The CI failure doesn't related to this change. Let's reapply it.

- vLLM version: v0.14.0
- vLLM main:
d68209402d
This commit is contained in:
wangxiyuan
2026-01-26 09:04:54 +08:00
committed by GitHub
parent c38c838d03
commit 4e3919e965
10 changed files with 420 additions and 415 deletions

View File

@@ -24,12 +24,15 @@ from vllm.forward_context import BatchDescriptor, ForwardContext
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
AscendMetadataForDecode)
from vllm_ascend.attention.context_parallel.attention_cp import \
AscendAttentionCPImpl
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAMetadata)
from vllm_ascend.compilation.acl_graph import (
ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params,
set_draft_graph_params, set_graph_params, update_attn_dcp_pcp_params,
update_draft_graph_params_workspaces, update_mla_attn_dcp_pcp_params)
set_draft_graph_params, set_graph_params,
update_draft_graph_params_workspaces)
class TestACLGraphEntry(TestBase):
@@ -811,8 +814,9 @@ class TestPCPDCPGraphParams(TestBase):
out, lse))
with patch("torch_npu._C._npu_setStream", return_value=None):
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
4)
AscendMlaCPImpl.update_graph_params(
self.update_stream, forward_context, 4
)
_mock_graph_task_end.assert_called_once()
@@ -852,6 +856,8 @@ class TestPCPDCPGraphParams(TestBase):
out, lse, 2, 0, 0))
with patch("torch_npu._C._npu_setStream", return_value=None):
update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
AscendAttentionCPImpl.update_graph_params(
self.update_stream, forward_context, 4, None
)
_mock_graph_task_end.assert_called_once()

View File

@@ -333,11 +333,11 @@ class TestEagleProposerDummyRun(TestBase):
self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
self.assertTrue(self.proposer._runnable.call_count == 1)
@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
mock_update_attn_params):
mock_update_full_graph_params):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -352,14 +352,14 @@ class TestEagleProposerDummyRun(TestBase):
in_graph_capturing=True,
aclgraph_runtime_mode=CUDAGraphMode.FULL)
self.assertTrue(self.proposer._runnable.call_count == 1)
mock_update_attn_params.assert_not_called()
mock_update_full_graph_params.assert_not_called()
self.proposer.use_cuda_graph = last_use_cuda_graph
@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
mock_update_attn_params):
mock_update_full_graph_params):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -374,7 +374,7 @@ class TestEagleProposerDummyRun(TestBase):
in_graph_capturing=False,
aclgraph_runtime_mode=CUDAGraphMode.FULL)
self.assertTrue(self.proposer._runnable.call_count == 1)
self.assertTrue(mock_update_attn_params.call_count == 1)
self.assertTrue(mock_update_full_graph_params.call_count == 1)
self.proposer.use_cuda_graph = last_use_cuda_graph

View File

@@ -371,6 +371,144 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)
@staticmethod
def update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config,
speculative_config=None,
num_dcp_pcp_tokens=None,
):
if using_paged_attention(num_tokens, vllm_config):
# Paged Attention update logic
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
)
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace,
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
else:
# FIA update logic
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
attn_metadata = forward_context.draft_attn_metadatas
attn_keys = list(attn_metadata[0].keys())
else:
graph_params = get_graph_params()
attn_metadata = forward_context.attn_metadata
attn_keys = list(attn_metadata.keys())
# For Qwen3-next, since the kv_cache_config has already categorized
# linear_attn and self_attn, the attn_metadata is first arranged with
# self_attn followed by linear_attn. Therefore, using zip directly
# filters out the update operations for linear_attn.
# TODO: We use a new variable `attn_keys` to ensure the loop count is
# correct after get by `zip` because of the new structure of the attn_metadata
# when running with the merged full eagle-graph. Should check it with Qwen3-next.
num_layers = len(attn_keys)
if num_layers == 0:
return
if forward_context.is_draft_model:
attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
attn_count = 0
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
attn_keys,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
query,
key_cache,
value,
block_tables,
attn_mask,
block_size,
seq_lens,
query_start_loc,
num_kv_heads,
num_heads,
scale,
attn_output,
softmax_lse,
) = param
if forward_context.is_draft_model:
draft_step = attn_count // num_layers
seq_lens = attn_metadata[draft_step][key].seq_lens_list
actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
attn_count = attn_count + 1
else:
seq_lens = attn_metadata[key].seq_lens_list
actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
key=key_cache,
value=value,
block_table=block_tables,
atten_mask=attn_mask,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=seq_lens,
num_key_value_heads=num_kv_heads,
num_heads=num_heads,
scale=scale,
sparse_mode=3,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def process_weights_after_loading(self, act_dtype: torch.dtype):
super().process_weights_after_loading(act_dtype)
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():

View File

@@ -276,6 +276,79 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
@staticmethod
def update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config,
speculative_config=None,
num_dcp_pcp_tokens=None,
):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
q_nope,
k_nope,
value,
num_heads,
num_kv_heads,
scale,
block_table,
block_size,
actual_seq_lengths_kv,
actual_seq_lengths_q,
attn_output,
softmax_lse,
dcp_size,
pcp_rank,
dcp_rank,
) = param
attn_metadata = forward_context.attn_metadata[key]
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
pad_length = num_tokens - len(actual_seq_lengths_kv)
if pad_length > 0:
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
if dcp_size > 1:
num_heads = num_heads * dcp_size
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
value,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout="TND",
atten_mask=None,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths_q,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def _attention_with_nomask_and_mask(
self,
q: torch.Tensor,

View File

@@ -284,6 +284,85 @@ class AscendMlaCPImpl(AscendMLAImpl):
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
@staticmethod
def update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config=None,
speculative_config=None,
num_dcp_pcp_tokens=None,
):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
q_nope,
k_nope,
q_pe,
k_pe,
num_heads,
num_kv_heads,
input_layout,
spec_attn_mask,
sparse_mode,
scale,
block_table,
block_size,
actual_seq_lengths,
actual_seq_lengths_kv,
attn_output,
softmax_lse,
) = param
decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len
if isinstance(seq_len, torch.Tensor):
seq_len = seq_len.tolist()
actual_seq_lengths_kv = seq_len
pad_length = num_tokens - len(actual_seq_lengths_kv)
if pad_length > 0:
actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (num_tokens - len(actual_seq_lengths_kv))
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def get_num_actual_tokens(self, attn_metadata: M):
if self.pcp_size > 1:
return attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size

View File

@@ -720,6 +720,88 @@ class AscendMLAImpl(MLAAttentionImpl):
)
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
@staticmethod
def update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config=None,
speculative_config=None,
num_dcp_pcp_tokens=None,
):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
q_nope,
k_nope,
q_pe,
k_pe,
num_heads,
num_kv_heads,
input_layout,
attn_mask,
sparse_mode,
scale,
block_table,
block_size,
seq_lens_list,
actual_seq_lengths,
attn_output,
softmax_lse,
) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)]
elif forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[key].decode.block_table
# TODO: This is a hack and should be fixed in the future.
if speculative_config.disable_padded_drafter_batch:
block_table = block_table[: len(actual_seq_lengths)]
seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
else:
seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout=input_layout,
atten_mask=attn_mask,
sparse_mode=sparse_mode,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=seq_lens_list,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def _v_up_proj(self, x):
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
x = x.view(self.num_heads, -1, self.kv_lora_rank)

View File

@@ -8,7 +8,6 @@ from dataclasses import dataclass
from typing import Any
from unittest.mock import patch
import numpy as np
import torch
import torch_npu
import vllm.envs as envs
@@ -20,8 +19,6 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.platforms import current_platform
from vllm_ascend.attention.utils import using_paged_attention
from ..utils import weak_ref_tensors
@@ -213,343 +210,24 @@ def weak_ref_workspaces(params):
params.workspaces[num_tokens] = weak_ref_tensors(params.workspaces[num_tokens])
def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=graph_params.workspaces.get(runtime_shape),
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def _update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas=None):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
attn_metadata = draft_attn_metadatas
attn_keys = list(attn_metadata[0].keys())
else:
graph_params = get_graph_params()
attn_metadata = forward_context.attn_metadata
attn_keys = list(attn_metadata.keys())
# For Qwen3-next, since the kv_cache_config has already categorized
# linear_attn and self_attn, the attn_metadata is first arranged with
# self_attn followed by linear_attn. Therefore, using zip directly
# filters out the update operations for linear_attn.
# TODO: We use a new variable `attn_keys` to ensure the loop count is
# correct after get by `zip` because of the new structure of the attn_metadata
# when running with the merged full eagle-graph. Should check it with Qwen3-next.
num_layers = len(attn_keys)
if num_layers == 0:
return
if forward_context.is_draft_model:
attn_keys = attn_keys * (len(graph_params.attn_params[runtime_shape]) // num_layers)
attn_count = 0
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
attn_keys,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value,
block_tables,
attn_mask,
block_size,
seq_lens,
query_start_loc,
num_kv_heads,
num_heads,
scale,
attn_output,
softmax_lse,
) = param
if forward_context.is_draft_model:
draft_step = attn_count // num_layers
seq_lens = attn_metadata[draft_step][key].seq_lens_list
actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
attn_count = attn_count + 1
else:
seq_lens = attn_metadata[key].seq_lens_list
actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
key=key_cache,
value=value,
block_table=block_tables,
atten_mask=attn_mask,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=seq_lens,
num_key_value_heads=num_kv_heads,
num_heads=num_heads,
scale=scale,
sparse_mode=3,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def update_attn_params(update_stream, forward_context, runtime_shape, vllm_config, draft_attn_metadatas=None):
if using_paged_attention(runtime_shape, vllm_config):
_update_attn_pa_params(update_stream, forward_context, runtime_shape)
else:
_update_attn_fia_params(update_stream, forward_context, runtime_shape, draft_attn_metadatas)
def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
q_nope,
k_nope,
q_pe,
k_pe,
num_heads,
num_kv_heads,
input_layout,
attn_mask,
sparse_mode,
scale,
block_table,
block_size,
seq_lens_list,
actual_seq_lengths,
attn_output,
softmax_lse,
) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
seq_lens_list = seq_lens_list + [0] * (runtime_shape // spec_multiple - len(seq_lens_list))
actual_seq_lengths = [spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)]
elif forward_context.is_draft_model:
actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[key].decode.block_table
# TODO: This is a hack and should be fixed in the future.
if speculative_config.disable_padded_drafter_batch:
block_table = block_table[: len(actual_seq_lengths)]
seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout=input_layout,
atten_mask=attn_mask,
sparse_mode=sparse_mode,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=seq_lens_list,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
graph_params = get_graph_params()
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
q_nope,
k_nope,
value,
num_heads,
num_kv_heads,
scale,
block_table,
block_size,
actual_seq_lengths_kv,
actual_seq_lengths_q,
attn_output,
softmax_lse,
dcp_size,
pcp_rank,
dcp_rank,
) = param
attn_metadata = forward_context.attn_metadata[key]
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
pad_length = runtime_shape - len(actual_seq_lengths_kv)
if pad_length > 0:
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
if dcp_size > 1:
num_heads = num_heads * dcp_size
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
value,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout="TND",
atten_mask=None,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths_q,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
q_nope,
k_nope,
q_pe,
k_pe,
num_heads,
num_kv_heads,
input_layout,
spec_attn_mask,
sparse_mode,
scale,
block_table,
block_size,
actual_seq_lengths,
actual_seq_lengths_kv,
attn_output,
softmax_lse,
) = param
decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len
if isinstance(seq_len, torch.Tensor):
seq_len = seq_len.tolist()
actual_seq_lengths_kv = seq_len
pad_length = runtime_shape - len(actual_seq_lengths_kv)
if pad_length > 0:
actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (runtime_shape - len(actual_seq_lengths_kv))
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
def update_full_graph_params(
attn_backend,
update_stream,
forward_context,
num_tokens,
vllm_config,
speculative_config=None,
num_dcp_pcp_tokens=None,
):
impl_cls = attn_backend.get_impl_cls()
impl_cls.update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config,
speculative_config,
num_dcp_pcp_tokens,
)
@dataclass

View File

@@ -416,7 +416,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
# NOTE: Must process Attention/MLAAttention before MambaBase to maintain
# ordering expected by acl_graph.py's _update_attn_fia_params.
# ordering expected by graph parameter update logic in attention backends.
mamba_layers: dict[str, MambaBase] = {}
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention):

View File

@@ -36,10 +36,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
update_attn_dcp_pcp_params,
update_attn_params,
update_mla_attn_dcp_pcp_params,
update_mla_attn_params)
update_full_graph_params)
from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.ops.triton.spec_decode.utils import \
prepare_inputs_padded_kernel
@@ -1181,21 +1178,9 @@ class EagleProposer(VllmEagleProposer):
# update full-graph params for one spec token
def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
if self.vllm_config.model_config.use_mla:
if self.pcp_size * self.dcp_size > 1:
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context, num_tokens)
else:
update_mla_attn_params(self.update_stream, forward_context,
num_tokens,
self.vllm_config.speculative_config)
else:
if self.pcp_size * self.dcp_size > 1:
update_attn_dcp_pcp_params(self.update_stream, forward_context,
num_tokens)
else:
update_attn_params(self.update_stream, forward_context,
num_tokens, self.vllm_config, draft_attn_metadatas)
update_full_graph_params(
self.runner.attn_backend, self.update_stream, forward_context, num_tokens,
self.vllm_config, self.vllm_config.speculative_config)
# padding tensor into desired size
def _pad_tensor(self, tensor, pad_size):

View File

@@ -84,10 +84,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_pag
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
set_draft_graph_params,
set_graph_params,
update_attn_dcp_pcp_params,
update_attn_params,
update_mla_attn_dcp_pcp_params,
update_mla_attn_params)
update_full_graph_params)
# yapf: enable
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
@@ -1142,26 +1139,9 @@ class NPUModelRunner(GPUModelRunner):
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
and not self.use_sparse:
# TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
if self.vllm_config.model_config.use_mla:
if self.pcp_size * self.dcp_size > 1:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context,
maybe_padded_num_tokens)
else:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context,
maybe_padded_num_tokens,
self.speculative_config)
else:
if self.pcp_size * self.dcp_size > 1:
update_attn_dcp_pcp_params(self.update_stream,
forward_context,
maybe_padded_num_tokens)
else:
update_attn_params(self.update_stream, forward_context,
maybe_padded_num_tokens,
self.vllm_config)
update_full_graph_params(self.attn_backend, self.update_stream, forward_context,
maybe_padded_num_tokens, self.vllm_config,
self.vllm_config.speculative_config)
if get_forward_context().sp_enabled and not isinstance(
hidden_states, IntermediateTensors):
@@ -2038,25 +2018,9 @@ class NPUModelRunner(GPUModelRunner):
assert forward_context is not None
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
not forward_context.capturing and not self.use_sparse:
if self.vllm_config.model_config.use_mla:
# FIXME: Try using `auto_dispatch_capture=True`
if self.pcp_size * self.dcp_size > 1:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context,
positions.shape[0])
else:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context,
num_tokens, self.speculative_config)
else:
if self.pcp_size * self.dcp_size > 1:
update_attn_dcp_pcp_params(self.update_stream,
forward_context,
positions.shape[0])
else:
update_attn_params(self.update_stream, forward_context,
num_tokens, self.vllm_config)
update_full_graph_params(self.attn_backend, self.update_stream, forward_context,
num_tokens, self.vllm_config,
self.speculative_config, positions.shape[0])
if self.use_aux_hidden_state_outputs:
hidden_states, _ = hidden_states
@@ -2899,7 +2863,7 @@ class NPUModelRunner(GPUModelRunner):
attn_layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase)
# NOTE: Must process Attention/MLAAttention before MambaBase to maintain
# ordering expected by acl_graph.py's _update_attn_fia_params.
# ordering expected by graph parameter update logic in attention backends.
mamba_layers: dict[str, MambaBase] = {}
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention):