[main][Refactor] Remove with_prefill parameter from set_ascend_forward_context (#5094)
Removes the redundant `with_prefill` parameter from
`set_ascend_forward_context` to align the interface with vLLM's
`set_forward_context` for future refactoring.
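A minimal before/after sketch of what this change implies at a call site (a
hedged illustration; the surrounding arguments and model call are assumptions,
not code quoted from the repo):

    # before: the Ascend wrapper carried an extra keyword that vLLM's
    # set_forward_context does not accept
    with set_ascend_forward_context(attn_metadata, vllm_config,
                                    with_prefill=with_prefill):
        hidden_states = model(input_ids, positions)

    # after: the keyword is dropped so the two interfaces line up;
    # callers thread has_prefill directly to the code that needs it
    with set_ascend_forward_context(attn_metadata, vllm_config):
        hidden_states = model(input_ids, positions)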
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Signed-off-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
@@ -476,12 +476,14 @@ class AscendSFAImpl(MLAAttentionImpl):
         # if mlapo, W_UK_T can't trans nz
         self.W_UK_T = maybe_trans_nz(self.W_UK_T)
 
-    def _v_up_proj(self, x):
-        forward_context = get_forward_context()
+    def _v_up_proj(self, x, has_prefill: bool):
+        # TODO(zzzzwwjj): We should not judge by `has_prefill` here.
+        # The true criterion is tensorA's shape[0] <= 1024 (num_tokens <= 1024).
+        # This is a bug in the previous code.
         if x.dtype in [torch.float16, torch.bfloat16] \
                 and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
                 and not self.enable_sfa_cp \
-                and not forward_context.with_prefill:
+                and not has_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),
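The TODO above says the fused path should be gated on the token count rather
than on `has_prefill`. A standalone sketch of that intended gate (the helper
name and the `max_fused_tokens` parameter are hypothetical; the condition
otherwise mirrors the hunk):

    import torch

    def use_bmm_transpose_path(x: torch.Tensor,
                               enable_sfa_cp: bool,
                               max_fused_tokens: int = 1024) -> bool:
        # Gate the custom batch_matmul_transpose kernel on the actual
        # token count (x.shape[0]), per the TODO, instead of has_prefill.
        return (x.dtype in (torch.float16, torch.bfloat16)
                and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
                and not enable_sfa_cp
                and x.shape[0] <= max_fused_tokens)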
@@ -766,7 +768,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
 
-        if self.enable_mlapo and not forward_context.with_prefill:
+        # TODO(zzzzwwjj): In SFA, prefill and decode share the same calculation
+        # formula, so `has_prefill` is not necessary here.
+        if self.enable_mlapo and not has_prefill:
             hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
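The second TODO points at a further simplification: if SFA really uses one
formula for prefill and decode, this gate can stop consulting `has_prefill`
altogether. A hypothetical sketch of that direction (the function and the
`sfa_unified` flag are invented for illustration):

    def mlapo_preprocess_enabled(enable_mlapo: bool,
                                 has_prefill: bool,
                                 sfa_unified: bool = False) -> bool:
        # Current behavior from the hunk above: the fused preprocess
        # only runs for pure-decode batches.
        if not sfa_unified:
            return enable_mlapo and not has_prefill
        # Direction the TODO suggests: prefill and decode share one
        # formula, so has_prefill no longer matters.
        return enable_mlapo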
@@ -841,7 +845,7 @@ class AscendSFAImpl(MLAAttentionImpl):
                 layout_kv="PA_BSND",
                 sparse_mode=3,
             )
-        attn_output = self._v_up_proj(attn_output)
+        attn_output = self._v_up_proj(attn_output, has_prefill)
         maybe_npu_prefetch(inputs=self.o_proj.weight,
                            dependency=attn_output,
                            max_size=MAX_O_PROJ_PREFETCH_SIZE,