[main][Refactor] Remove with_prefill parameter from set_ascend_forward_context (#5094)

Removes the redundant `with_prefill` parameter from
`set_ascend_forward_context`, aligning the interface with vLLM's
`set_forward_context` in preparation for future refactoring.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
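
For readers skimming the diff below, here is a minimal, self-contained sketch of the interface shape after the change. The helper body, the `_ToyForwardContext` class, and the `num_prefills` field are illustrative stand-ins rather than the actual vLLM-Ascend implementation; only the removal of the `with_prefill` parameter mirrors this commit.

```python
from contextlib import contextmanager
from dataclasses import dataclass

_FORWARD_CONTEXT = None

@dataclass
class _ToyForwardContext:
    attn_metadata: dict  # prefill/decode info lives in the metadata, not on the context

@contextmanager
def set_ascend_forward_context(attn_metadata):  # the `with_prefill=...` argument is gone
    global _FORWARD_CONTEXT
    _FORWARD_CONTEXT = _ToyForwardContext(attn_metadata=attn_metadata)
    try:
        yield _FORWARD_CONTEXT
    finally:
        _FORWARD_CONTEXT = None

# Call sites that used to read `forward_context.with_prefill` now derive the flag
# themselves (here from a toy metadata dict) and pass it down explicitly.
with set_ascend_forward_context({"num_prefills": 0}) as ctx:
    has_prefill = ctx.attn_metadata["num_prefills"] > 0
    print(has_prefill)  # False
```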

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Signed-off-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Author: Slightwind
Date: 2025-12-23 14:30:50 +08:00
Committed by: GitHub
Parent: fa0c212bfa
Commit: 22138e2727
6 changed files with 22 additions and 21 deletions


@@ -476,12 +476,14 @@ class AscendSFAImpl(MLAAttentionImpl):
             # if mlapo, W_UK_T can't trans nz
             self.W_UK_T = maybe_trans_nz(self.W_UK_T)
 
-    def _v_up_proj(self, x):
-        forward_context = get_forward_context()
+    def _v_up_proj(self, x, has_prefill: bool):
+        # TODO(zzzzwwjj): We should not judge by whether `has_prefill` or not.
+        # The true criteria for judgment is tensorA's shape[0] <= 1024 (num_tokens <= 1024).
+        # This is a bug in the previous code.
         if x.dtype in [torch.float16, torch.bfloat16] \
                 and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
                 and not self.enable_sfa_cp \
-                and not forward_context.with_prefill:
+                and not has_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),
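
The TODO in the hunk above states that the real gate for the fused path should be the token count (`shape[0] <= 1024`) rather than the prefill flag. Below is a hedged sketch of what such a predicate could look like as a follow-up; the standalone helper and its name are hypothetical, since this commit only swaps `forward_context.with_prefill` for the `has_prefill` argument.

```python
import torch

MAX_FUSED_TOKENS = 1024  # threshold quoted in the TODO comment above

def should_use_fused_v_up_proj(x: torch.Tensor, enable_sfa_cp: bool) -> bool:
    # Hypothetical follow-up predicate: gate on the actual token count
    # (x.shape[0]) instead of whether the batch contains prefill requests.
    return (
        x.dtype in (torch.float16, torch.bfloat16)
        and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
        and not enable_sfa_cp
        and x.shape[0] <= MAX_FUSED_TOKENS
    )
```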
@@ -766,7 +768,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        if self.enable_mlapo and not forward_context.with_prefill:
+        # TODO(zzzzwwjj): In sfa, prefill and decode have the same calculation formula,
+        # so `has_prefill` here is not necessary.
+        if self.enable_mlapo and not has_prefill:
             hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
@@ -841,7 +845,7 @@
             layout_kv="PA_BSND",
             sparse_mode=3,
         )
-        attn_output = self._v_up_proj(attn_output)
+        attn_output = self._v_up_proj(attn_output, has_prefill)
         maybe_npu_prefetch(inputs=self.o_proj.weight,
                            dependency=attn_output,
                            max_size=MAX_O_PROJ_PREFETCH_SIZE,
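
Putting the three hunks together, the data flow after this commit is roughly as follows. This is a simplified, hedged sketch with toy names (`ToySFAImpl`, `num_prefill_tokens`); the real `forward` derives `has_prefill` from its Ascend attention metadata rather than from an integer argument.

```python
import torch

class ToySFAImpl:
    def _v_up_proj(self, attn_output: torch.Tensor, has_prefill: bool) -> torch.Tensor:
        # In the real code the fused batch_matmul_transpose path is skipped when
        # has_prefill is True; this toy version just flattens the head dimensions
        # so the example stays runnable without Ascend ops.
        return attn_output.reshape(attn_output.shape[0], -1)

    def forward(self, attn_output: torch.Tensor, num_prefill_tokens: int) -> torch.Tensor:
        # The prefill flag is computed once from metadata the caller already has,
        # instead of being read from `forward_context.with_prefill` inside helpers.
        has_prefill = num_prefill_tokens > 0
        return self._v_up_proj(attn_output, has_prefill)

out = ToySFAImpl().forward(torch.randn(4, 8, 16), num_prefill_tokens=0)
print(out.shape)  # torch.Size([4, 128])
```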