From 22138e2727f43ad43898734325969fe298c142e0 Mon Sep 17 00:00:00 2001
From: Slightwind
Date: Tue, 23 Dec 2025 14:30:50 +0800
Subject: [PATCH] [main][Refactor] Remove `with_prefill` parameter from
 `set_ascend_forward_context` (#5094)

Removes the now-redundant `with_prefill` parameter from
`set_ascend_forward_context`, aligning the interface with vLLM's
`set_forward_context` ahead of future refactoring. Attention backends
derive the flag locally as `attn_metadata.num_prefills > 0` instead of
reading it back from the forward context.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: SlightwindSec
Signed-off-by: Slightwind
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
---
 vllm_ascend/ascend_forward_context.py   |  2 --
 vllm_ascend/attention/mla_cp.py         | 12 +++++++-----
 vllm_ascend/attention/mla_v1.py         | 11 ++++++-----
 vllm_ascend/attention/sfa_v1.py         | 14 +++++++++-----
 vllm_ascend/spec_decode/mtp_proposer.py |  2 --
 vllm_ascend/worker/model_runner_v1.py   |  2 --
 6 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index e12b45fa..ada22242 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -31,7 +31,6 @@ def set_ascend_forward_context(
         virtual_engine: int = 0,
         num_tokens: int = 0,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
-        with_prefill: bool = True,
         in_profile_run: bool = False,
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
@@ -60,7 +59,6 @@ def set_ascend_forward_context(
 
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
-        forward_context.with_prefill = with_prefill
 
         tp_world_size = get_tensor_model_parallel_world_size()
         forward_context.in_profile_run = in_profile_run
diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py
index 2ddc41db..ae24557a 100644
--- a/vllm_ascend/attention/mla_cp.py
+++ b/vllm_ascend/attention/mla_cp.py
@@ -594,6 +594,9 @@ class AscendMlaCPImpl(AscendMLAImpl):
                     self.vllm_config, self.o_proj):
                 reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
+
+        forward_context = get_forward_context()
+
         if self.pcp_size > 1:
             num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
         else:
@@ -600,20 +603,20 @@ class AscendMlaCPImpl(AscendMLAImpl):
             num_actual_tokens = attn_metadata.num_actual_tokens
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
+
+        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        o_proj_input_shape = (get_forward_context().num_tokens,
+        o_proj_input_shape = (forward_context.num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
                                    device=hidden_states.device)
 
         # MLA Preprocess
-        forward_context = get_forward_context()
-        if (self.enable_mlapo and
-                (attn_metadata is None or not forward_context.with_prefill)):
+        if self.enable_mlapo and not has_prefill:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
@@ -671,7 +674,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
 
         del o_proj_input
 
-        has_prefill = attn_metadata.num_prefills > 0
         if has_prefill:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 352f1d33..6e47b98e 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1395,23 +1395,25 @@ class AscendMLAImpl(MLAAttentionImpl):
                     self.vllm_config, self.o_proj):
                 reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
+
+        forward_context = get_forward_context()
         num_actual_tokens = attn_metadata.num_actual_tokens
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
+
+        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        o_proj_input_shape = (get_forward_context().num_tokens,
+        o_proj_input_shape = (forward_context.num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
                                    device=hidden_states.device)
 
         # MLA Preprocess
-        forward_context = get_forward_context()
-        if (self.enable_mlapo and
-                (attn_metadata is None or not forward_context.with_prefill)):
+        if self.enable_mlapo and not has_prefill:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
@@ -1455,7 +1457,6 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         del o_proj_input
 
-        has_prefill = attn_metadata.num_prefills > 0
         if has_prefill:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 32a66c0e..48aac26c 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -476,12 +476,14 @@ class AscendSFAImpl(MLAAttentionImpl):
         # if mlapo, W_UK_T can't trans nz
         self.W_UK_T = maybe_trans_nz(self.W_UK_T)
 
-    def _v_up_proj(self, x):
-        forward_context = get_forward_context()
+    def _v_up_proj(self, x, has_prefill: bool):
+        # TODO(zzzzwwjj): We should not decide this path by `has_prefill`;
+        # the true criterion is that tensorA's shape[0] <= 1024 (num_tokens <= 1024).
+        # This is a bug in the previous code.
         if x.dtype in [torch.float16, torch.bfloat16] \
             and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
             and not self.enable_sfa_cp \
-            and not forward_context.with_prefill:
+            and not has_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),
@@ -766,7 +768,9 @@ class AscendSFAImpl(MLAAttentionImpl):
 
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        if self.enable_mlapo and not forward_context.with_prefill:
+        # TODO(zzzzwwjj): In SFA, prefill and decode share the same calculation
+        # formula, so `has_prefill` is not strictly necessary here.
+        if self.enable_mlapo and not has_prefill:
             hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
@@ -841,7 +845,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             layout_kv="PA_BSND",
             sparse_mode=3,
         )
-        attn_output = self._v_up_proj(attn_output)
+        attn_output = self._v_up_proj(attn_output, has_prefill)
         maybe_npu_prefetch(inputs=self.o_proj.weight,
                            dependency=attn_output,
                            max_size=MAX_O_PROJ_PREFETCH_SIZE,
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 9f71183b..66dd65bd 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -299,7 +299,6 @@ class MtpProposer(Proposer):
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,
-                with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
@@ -779,7 +778,6 @@ class MtpProposer(Proposer):
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_input_tokens,
-                with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index d8454674..c048a577 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1424,7 +1424,6 @@ class NPUModelRunner(GPUModelRunner):
                 self.vllm_config,
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
-                with_prefill=self.with_prefill,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
@@ -2137,7 +2136,6 @@ class NPUModelRunner(GPUModelRunner):
                 self.vllm_config,
                 num_tokens=num_tokens_padded,
                 num_tokens_across_dp=num_tokens_across_dp,
-                with_prefill=with_prefill,
                 in_profile_run=is_profile,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
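
For reviewers, a minimal sketch of the call-site pattern this patch establishes: instead of threading `with_prefill` through `set_ascend_forward_context` and reading it back as `forward_context.with_prefill`, each attention backend derives the flag locally from metadata it already holds. `AttnMetadata` and `forward_attention` below are simplified stand-ins, not the actual vllm-ascend classes or signatures.

```python
from dataclasses import dataclass


@dataclass
class AttnMetadata:
    """Simplified stand-in for the Ascend MLA attention metadata."""
    num_prefills: int
    num_decode_tokens: int


def forward_attention(attn_metadata: AttnMetadata, enable_mlapo: bool) -> str:
    # Before this patch: `not forward_context.with_prefill`, set globally
    # by the caller. After: derived at the point of use.
    has_prefill = attn_metadata.num_prefills > 0
    if enable_mlapo and not has_prefill:
        return "fused decode-only preprocess (MLAPO)"
    return "generic prefill/decode preprocess"


# A decode-only batch takes the fused path; any prefill falls back.
print(forward_attention(AttnMetadata(num_prefills=0, num_decode_tokens=8), True))
print(forward_attention(AttnMetadata(num_prefills=2, num_decode_tokens=8), True))
```

Deriving the flag where it is consumed is what lets `set_ascend_forward_context` drop the extra parameter and stay signature-compatible with vLLM's `set_forward_context`.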
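The TODO added in `_v_up_proj` notes that the fused `batch_matmul_transpose` path should really be gated by the operand's row count (shape[0] <= 1024), not by whether the batch contains prefills. A hedged sketch of that criterion follows; `MAX_FUSED_ROWS` and `v_up_proj_dispatch` are illustrative names, and the real gate additionally checks for the `torch.ops._C_ascend` custom op.

```python
import torch

MAX_FUSED_ROWS = 1024  # per the TODO: tensorA's shape[0] <= 1024


def v_up_proj_dispatch(x: torch.Tensor) -> str:
    # Gate on the actual row count rather than `has_prefill`: a decode-only
    # batch may still exceed 1024 tokens, and a short prefill may fit.
    if x.dtype in (torch.float16, torch.bfloat16) and x.shape[0] <= MAX_FUSED_ROWS:
        return "fused batch_matmul_transpose path"
    return "plain matmul fallback"


print(v_up_proj_dispatch(torch.zeros(512, 16, 512, dtype=torch.float16)))   # fused
print(v_up_proj_dispatch(torch.zeros(2048, 16, 512, dtype=torch.float16)))  # fallback
```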