[main][Refactor] Remove with_prefill parameter from set_ascend_forward_context (#5094)

Removes the redundant `with_prefill` parameter from
`set_ascend_forward_context`, aligning the interface with vLLM's
`set_forward_context` ahead of further refactoring. Call sites that
previously read `forward_context.with_prefill` now derive the flag
locally from the batch's attention metadata (`attn_metadata.num_prefills > 0`).
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
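
For illustration, a minimal sketch of a call site after this change. The
surrounding names (`attn_metadata`, `num_input_tokens`, etc.) are stand-ins
for the real variables at each call site, not copied from any one of them:

```python
# Hedged sketch: with_prefill= is no longer passed in; consumers that need
# the flag derive it from the batch's attention metadata instead.
with set_ascend_forward_context(
        attn_metadata,
        vllm_config,
        num_tokens=num_input_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        aclgraph_runtime_mode=aclgraph_runtime_mode,
        batch_descriptor=batch_descriptor):
    has_prefill = attn_metadata.num_prefills > 0  # local, per batch
    ...
```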

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Signed-off-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Author: Slightwind
Committed: 2025-12-23 14:30:50 +08:00 (via GitHub)
Parent: fa0c212bfa
Commit: 22138e2727
6 changed files with 22 additions and 21 deletions


@@ -31,7 +31,6 @@ def set_ascend_forward_context(
     virtual_engine: int = 0,
     num_tokens: int = 0,
     num_tokens_across_dp: Optional[torch.Tensor] = None,
-    with_prefill: bool = True,
     in_profile_run: bool = False,
     num_actual_tokens: Optional[int] = None,
     aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
@@ -60,7 +59,6 @@ def set_ascend_forward_context(
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
-        forward_context.with_prefill = with_prefill
         tp_world_size = get_tensor_model_parallel_world_size()
         forward_context.in_profile_run = in_profile_run


@@ -594,6 +594,9 @@ class AscendMlaCPImpl(AscendMLAImpl):
                     self.vllm_config, self.o_proj):
                 reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
+        forward_context = get_forward_context()
         if self.pcp_size > 1:
             num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
         else:
@@ -601,19 +604,19 @@ class AscendMlaCPImpl(AscendMLAImpl):
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
+        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        o_proj_input_shape = (get_forward_context().num_tokens,
+        o_proj_input_shape = (forward_context.num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
                                    device=hidden_states.device)
         # MLA Preprocess
-        forward_context = get_forward_context()
-        if (self.enable_mlapo and
-                (attn_metadata is None or not forward_context.with_prefill)):
+        if self.enable_mlapo and not has_prefill:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
@@ -671,7 +674,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
         del o_proj_input
-        has_prefill = attn_metadata.num_prefills > 0
         if has_prefill:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded


@@ -1395,23 +1395,25 @@ class AscendMLAImpl(MLAAttentionImpl):
                     self.vllm_config, self.o_proj):
                 reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
+        forward_context = get_forward_context()
         num_actual_tokens = attn_metadata.num_actual_tokens
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
+        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        o_proj_input_shape = (get_forward_context().num_tokens,
+        o_proj_input_shape = (forward_context.num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
                                    device=hidden_states.device)
         # MLA Preprocess
-        forward_context = get_forward_context()
-        if (self.enable_mlapo and
-                (attn_metadata is None or not forward_context.with_prefill)):
+        if self.enable_mlapo and not has_prefill:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
@@ -1455,7 +1457,6 @@ class AscendMLAImpl(MLAAttentionImpl):
         del o_proj_input
-        has_prefill = attn_metadata.num_prefills > 0
         if has_prefill:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded

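Both MLA paths now compute the prefill flag once from the batch's attention
metadata instead of reading `forward_context.with_prefill`, which also drops
the duplicated `has_prefill` assignment near the end of each method. A minimal
sketch of the resulting pattern, with the surrounding method body elided:

```python
# Sketch only: names follow the hunks above; everything else is elided.
forward_context = get_forward_context()        # fetched once, reused below
has_prefill = attn_metadata.num_prefills > 0   # derived locally per batch

o_proj_input_shape = (forward_context.num_tokens,
                      self.num_heads * self.v_head_dim)
if self.enable_mlapo and not has_prefill:
    ...  # decode-only MLA preprocessing path
```
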

@@ -476,12 +476,14 @@ class AscendSFAImpl(MLAAttentionImpl):
             # if mlapo, W_UK_T can't trans nz
             self.W_UK_T = maybe_trans_nz(self.W_UK_T)

-    def _v_up_proj(self, x):
-        forward_context = get_forward_context()
+    def _v_up_proj(self, x, has_prefill: bool):
+        # TODO(zzzzwwjj): We should not judge by whether `has_prefill` or not.
+        # The true criteria for judgment is tensorA's shape[0] <= 1024 (num_tokens <= 1024).
+        # This is a bug in the previous code.
         if x.dtype in [torch.float16, torch.bfloat16] \
                 and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
                 and not self.enable_sfa_cp \
-                and not forward_context.with_prefill:
+                and not has_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),
@@ -766,7 +768,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        if self.enable_mlapo and not forward_context.with_prefill:
+        # TODO(zzzzwwjj): In sfa, prefill and decode have the same calculation formula,
+        # so `has_prefill` here is not necessary.
+        if self.enable_mlapo and not has_prefill:
             hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
@@ -841,7 +845,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             layout_kv="PA_BSND",
             sparse_mode=3,
         )
-        attn_output = self._v_up_proj(attn_output)
+        attn_output = self._v_up_proj(attn_output, has_prefill)
         maybe_npu_prefetch(inputs=self.o_proj.weight,
                            dependency=attn_output,
                            max_size=MAX_O_PROJ_PREFETCH_SIZE,

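The TODO in `_v_up_proj` notes that `has_prefill` is only a stand-in: the
fused `batch_matmul_transpose` kernel is actually limited by operand size
(shape[0] <= 1024). Purely as an illustration of that remark, a follow-up
gate might look like the following; the exact form is an assumption, not
part of this commit:

```python
# Hypothetical check suggested by the TODO (not in this commit):
# gate the fused kernel on the operand's token count, not on has_prefill.
use_fused_bmm_transpose = (
    x.dtype in (torch.float16, torch.bfloat16)
    and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
    and not self.enable_sfa_cp
    and x.shape[0] <= 1024  # limit quoted in the TODO (num_tokens <= 1024)
)
```
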

@@ -299,7 +299,6 @@ class MtpProposer(Proposer):
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,
-                with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
@@ -779,7 +778,6 @@ class MtpProposer(Proposer):
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_input_tokens,
-                with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,


@@ -1424,7 +1424,6 @@ class NPUModelRunner(GPUModelRunner):
                 self.vllm_config,
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
-                with_prefill=self.with_prefill,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
@@ -2137,7 +2136,6 @@ class NPUModelRunner(GPUModelRunner):
                 self.vllm_config,
                 num_tokens=num_tokens_padded,
                 num_tokens_across_dp=num_tokens_across_dp,
-                with_prefill=with_prefill,
                 in_profile_run=is_profile,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,