Revert "[Refactor] Unify full-graph parameter update logic (#6041)" (#6227)

This reverts commit 8966a99710.

It breaks the test
`tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py::test_deepseek_mtp_correctness[True-FULL_DECODE_ONLY-2-wemaster/deepseek_mtp_main_random_bf16]`

- vLLM version: v0.14.0
- vLLM main: d68209402d
Author: wangxiyuan
Date: 2026-01-25 15:25:38 +08:00
Committed by: GitHub
Parent: 7799c4ca3b
Commit: 95649344aa

10 changed files with 415 additions and 420 deletions


@@ -84,7 +84,10 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, using_pag
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_draft_graph_params,
                                                set_graph_params,
-                                               update_full_graph_params)
+                                               update_attn_dcp_pcp_params,
+                                               update_attn_params,
+                                               update_mla_attn_dcp_pcp_params,
+                                               update_mla_attn_params)
 # yapf: enable
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
@@ -1139,9 +1142,26 @@ class NPUModelRunner(GPUModelRunner):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
             and not self.use_sparse:
             # TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
-            update_full_graph_params(self.attn_backend, self.update_stream, forward_context,
-                                     maybe_padded_num_tokens, self.vllm_config,
-                                     self.vllm_config.speculative_config)
+            if self.vllm_config.model_config.use_mla:
+                if self.pcp_size * self.dcp_size > 1:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_dcp_pcp_params(self.update_stream,
+                                                   forward_context,
+                                                   maybe_padded_num_tokens)
+                else:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_params(self.update_stream, forward_context,
+                                           maybe_padded_num_tokens,
+                                           self.speculative_config)
+            else:
+                if self.pcp_size * self.dcp_size > 1:
+                    update_attn_dcp_pcp_params(self.update_stream,
+                                               forward_context,
+                                               maybe_padded_num_tokens)
+                else:
+                    update_attn_params(self.update_stream, forward_context,
+                                       maybe_padded_num_tokens,
+                                       self.vllm_config)
         if get_forward_context().sp_enabled and not isinstance(
                 hidden_states, IntermediateTensors):
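
The three-line call removed above folded the entire restored branch tree into a single entry point. For reference, the restored selection reduces to a four-way dispatch on the attention flavor (MLA or not) and on whether decode/prefill context parallelism spans more than one rank. A minimal sketch, using the four functions imported in the first hunk; the helper name select_update_fn is hypothetical:

    from vllm_ascend.compilation.acl_graph import (update_attn_dcp_pcp_params,
                                                   update_attn_params,
                                                   update_mla_attn_dcp_pcp_params,
                                                   update_mla_attn_params)

    def select_update_fn(use_mla: bool, pcp_size: int, dcp_size: int):
        """Hypothetical helper mirroring the restored dispatch above."""
        parallel = pcp_size * dcp_size > 1  # DCP/PCP spans more than one rank
        if use_mla:
            return update_mla_attn_dcp_pcp_params if parallel else update_mla_attn_params
        return update_attn_dcp_pcp_params if parallel else update_attn_params

Note that the four hooks do not share one call shape: the DCP/PCP variants take only a token count, while update_mla_attn_params also needs the speculative config and update_attn_params the full vllm_config, which is why the restored call sites spell out each branch.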
@@ -2018,9 +2038,25 @@ class NPUModelRunner(GPUModelRunner):
         assert forward_context is not None
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
             not forward_context.capturing and not self.use_sparse:
-            update_full_graph_params(self.attn_backend, self.update_stream, forward_context,
-                                     num_tokens, self.vllm_config,
-                                     self.speculative_config, positions.shape[0])
+            if self.vllm_config.model_config.use_mla:
+                # FIXME: Try using `auto_dispatch_capture=True`
+                if self.pcp_size * self.dcp_size > 1:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_dcp_pcp_params(self.update_stream,
+                                                   forward_context,
+                                                   positions.shape[0])
+                else:
+                    # FIXME: Try using `auto_dispatch_capture=True`
+                    update_mla_attn_params(self.update_stream, forward_context,
+                                           num_tokens, self.speculative_config)
+            else:
+                if self.pcp_size * self.dcp_size > 1:
+                    update_attn_dcp_pcp_params(self.update_stream,
+                                               forward_context,
+                                               positions.shape[0])
+                else:
+                    update_attn_params(self.update_stream, forward_context,
+                                       num_tokens, self.vllm_config)
         if self.use_aux_hidden_state_outputs:
             hidden_states, _ = hidden_states
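
Taken together, the two removed call sites also pin down the reverted helper's interface: six positional arguments at the first site, plus an optional positions.shape[0] at the second. A hedged reconstruction of how #6041's update_full_graph_params presumably dispatched, inferred from those call sites only; the real body is not shown in this diff, dcp_pcp_parallel is a hypothetical stand-in for the runner's self.pcp_size * self.dcp_size check, and attn_backend's role is not visible here:

    # Reconstruction inferred from the removed call sites; names follow the
    # diff, the body is an assumption about #6041's actual implementation.
    # Uses the acl_graph imports shown in the first hunk.
    def update_full_graph_params(attn_backend, update_stream, forward_context,
                                 num_tokens, vllm_config, speculative_config,
                                 num_positions=None):
        # Hypothetical flag; the restored code reads self.pcp_size * self.dcp_size.
        parallel = getattr(forward_context, "dcp_pcp_parallel", False)
        pcp_tokens = num_positions if num_positions is not None else num_tokens
        if vllm_config.model_config.use_mla:
            if parallel:
                update_mla_attn_dcp_pcp_params(update_stream, forward_context,
                                               pcp_tokens)
            else:
                update_mla_attn_params(update_stream, forward_context,
                                       num_tokens, speculative_config)
        elif parallel:
            update_attn_dcp_pcp_params(update_stream, forward_context,
                                       pcp_tokens)
        else:
            update_attn_params(update_stream, forward_context,
                               num_tokens, vllm_config)

Whatever its exact body, collapsing the branches required threading vllm_config and speculative_config through one signature; the revert trades that uniformity back for explicit per-case calls.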
@@ -2863,7 +2899,7 @@ class NPUModelRunner(GPUModelRunner):
         attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                   AttentionLayerBase)
         # NOTE: Must process Attention/MLAAttention before MambaBase to maintain
-        # ordering expected by graph parameter update logic in attention backends.
+        # ordering expected by acl_graph.py's _update_attn_fia_params.
         mamba_layers: dict[str, MambaBase] = {}
         for layer_name, attn_module in attn_layers.items():
             if isinstance(attn_module, Attention):
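
The restored NOTE is the contract this last hunk reinstates: attention layers must be processed before Mamba layers so that the per-layer parameter update in acl_graph.py's _update_attn_fia_params sees them in the expected order. A self-contained sketch of that ordering rule, with stub classes standing in for vLLM's Attention/MLAAttention and MambaBase:

    class Attention:      # stub for vllm's Attention / MLAAttention
        pass

    class MambaBase:      # stub for vllm's MambaBase
        pass

    def order_layers(attn_layers: dict[str, object]) -> list[str]:
        """Attention first, Mamba deferred, mirroring the NOTE above."""
        mamba_layers: dict[str, object] = {}
        ordered: list[str] = []
        for layer_name, attn_module in attn_layers.items():
            if isinstance(attn_module, Attention):
                ordered.append(layer_name)              # handled immediately
            elif isinstance(attn_module, MambaBase):
                mamba_layers[layer_name] = attn_module  # deferred until after attention
        ordered.extend(mamba_layers)
        return ordered

    # Mamba layers sort last regardless of dict insertion order:
    assert order_layers({"m0": MambaBase(), "a0": Attention(),
                         "a1": Attention()}) == ["a0", "a1", "m0"]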