[main][Refactor] Remove with_prefill parameter from set_ascend_forward_context (#5094)

Removes the redundant `with_prefill` parameter from
`set_ascend_forward_context` to align the interface with vLLM's
`set_forward_context` for future refactoring.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Signed-off-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
This commit is contained in:
Slightwind
2025-12-23 14:30:50 +08:00
committed by GitHub
parent fa0c212bfa
commit 22138e2727
6 changed files with 22 additions and 21 deletions

View File

@@ -1395,23 +1395,25 @@ class AscendMLAImpl(MLAAttentionImpl):
self.vllm_config, self.o_proj):
reach_layer_for_shared_weight_series(self.o_proj)
return output.fill_(0)
forward_context = get_forward_context()
num_actual_tokens = attn_metadata.num_actual_tokens
assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
o_proj_input_shape = (get_forward_context().num_tokens,
o_proj_input_shape = (forward_context.num_tokens,
self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
# MLA Preprocess
forward_context = get_forward_context()
if (self.enable_mlapo and
(attn_metadata is None or not forward_context.with_prefill)):
if self.enable_mlapo and not has_prefill:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), need_gather_q_kv)
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
@@ -1455,7 +1457,6 @@ class AscendMLAImpl(MLAAttentionImpl):
del o_proj_input
has_prefill = attn_metadata.num_prefills > 0
if has_prefill:
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded