[main][Refactor] Remove with_prefill parameter from set_ascend_forward_context (#5094)
Removes the redundant `with_prefill` parameter from
`set_ascend_forward_context` to align the interface with vLLM's
`set_forward_context` for future refactoring.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Signed-off-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
This commit is contained in:
@@ -1395,23 +1395,25 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.vllm_config, self.o_proj):
|
||||
reach_layer_for_shared_weight_series(self.o_proj)
|
||||
return output.fill_(0)
|
||||
|
||||
forward_context = get_forward_context()
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
attn_metadata.num_decode_tokens is not None
|
||||
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
output_padded = output
|
||||
o_proj_input_shape = (get_forward_context().num_tokens,
|
||||
o_proj_input_shape = (forward_context.num_tokens,
|
||||
self.num_heads * self.v_head_dim)
|
||||
o_proj_input = torch.empty(o_proj_input_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
# MLA Preprocess
|
||||
forward_context = get_forward_context()
|
||||
if (self.enable_mlapo and
|
||||
(attn_metadata is None or not forward_context.with_prefill)):
|
||||
if self.enable_mlapo and not has_prefill:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), need_gather_q_kv)
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
|
||||
@@ -1455,7 +1457,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
del o_proj_input
|
||||
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
if has_prefill:
|
||||
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
|
||||
return output_padded
|
||||
|
||||
Reference in New Issue
Block a user