[main][Refactor] Remove with_prefill parameter from set_ascend_forward_context (#5094)
Removes the redundant `with_prefill` parameter from
`set_ascend_forward_context` to align the interface with vLLM's
`set_forward_context` for future refactoring.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Signed-off-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
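
For context, a minimal before/after sketch of the call-site pattern this commit changes (the surrounding runner code is assumed; variable names are taken from the hunks below):

    # Before: the flag was passed into the context manager and attention
    # backends read it back from the forward context.
    with set_ascend_forward_context(attn_metadata,
                                    self.vllm_config,
                                    num_tokens=num_tokens,
                                    with_prefill=with_prefill,
                                    num_tokens_across_dp=num_tokens_across_dp):
        ...  # backends branch on get_forward_context().with_prefill

    # After: the parameter is gone; each attention backend derives the flag
    # locally from its attention metadata instead.
    with set_ascend_forward_context(attn_metadata,
                                    self.vllm_config,
                                    num_tokens=num_tokens,
                                    num_tokens_across_dp=num_tokens_across_dp):
        ...  # backends compute has_prefill = attn_metadata.num_prefills > 0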
@@ -31,7 +31,6 @@ def set_ascend_forward_context(
         virtual_engine: int = 0,
         num_tokens: int = 0,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
-        with_prefill: bool = True,
         in_profile_run: bool = False,
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
@@ -60,7 +59,6 @@ def set_ascend_forward_context(
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

-        forward_context.with_prefill = with_prefill
         tp_world_size = get_tensor_model_parallel_world_size()

         forward_context.in_profile_run = in_profile_run
@@ -594,6 +594,9 @@ class AscendMlaCPImpl(AscendMLAImpl):
                 self.vllm_config, self.o_proj):
             reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
+
+        forward_context = get_forward_context()
+
         if self.pcp_size > 1:
             num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
         else:
@@ -601,19 +604,19 @@ class AscendMlaCPImpl(AscendMLAImpl):
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
+
+        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        o_proj_input_shape = (get_forward_context().num_tokens,
+        o_proj_input_shape = (forward_context.num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
                                    device=hidden_states.device)

         # MLA Preprocess
-        forward_context = get_forward_context()
-        if (self.enable_mlapo and
-                (attn_metadata is None or not forward_context.with_prefill)):
+        if self.enable_mlapo and not has_prefill:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
@@ -671,7 +674,6 @@ class AscendMlaCPImpl(AscendMLAImpl):

         del o_proj_input

-        has_prefill = attn_metadata.num_prefills > 0
         if has_prefill:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded
@@ -1395,23 +1395,25 @@ class AscendMLAImpl(MLAAttentionImpl):
                 self.vllm_config, self.o_proj):
             reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
+
+        forward_context = get_forward_context()
         num_actual_tokens = attn_metadata.num_actual_tokens
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
+
+        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        o_proj_input_shape = (get_forward_context().num_tokens,
+        o_proj_input_shape = (forward_context.num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
                                    device=hidden_states.device)

         # MLA Preprocess
-        forward_context = get_forward_context()
-        if (self.enable_mlapo and
-                (attn_metadata is None or not forward_context.with_prefill)):
+        if self.enable_mlapo and not has_prefill:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
@@ -1455,7 +1457,6 @@ class AscendMLAImpl(MLAAttentionImpl):

         del o_proj_input

-        has_prefill = attn_metadata.num_prefills > 0
         if has_prefill:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded
@@ -476,12 +476,14 @@ class AscendSFAImpl(MLAAttentionImpl):
             # if mlapo, W_UK_T can't trans nz
             self.W_UK_T = maybe_trans_nz(self.W_UK_T)

-    def _v_up_proj(self, x):
-        forward_context = get_forward_context()
+    def _v_up_proj(self, x, has_prefill: bool):
+        # TODO(zzzzwwjj): We should not judge by whether `has_prefill` or not.
+        # The true criteria for judgment is tensorA's shape[0] <= 1024 (num_tokens <= 1024).
+        # This is a bug in the previous code.
         if x.dtype in [torch.float16, torch.bfloat16] \
                 and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
                 and not self.enable_sfa_cp \
-                and not forward_context.with_prefill:
+                and not has_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),
@@ -766,7 +768,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output

-        if self.enable_mlapo and not forward_context.with_prefill:
+        # TODO(zzzzwwjj): In sfa, prefill and decode have the same calculation formula,
+        # so `has_prefill` here is not necessary.
+        if self.enable_mlapo and not has_prefill:
             hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
@@ -841,7 +845,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             layout_kv="PA_BSND",
             sparse_mode=3,
         )
-        attn_output = self._v_up_proj(attn_output)
+        attn_output = self._v_up_proj(attn_output, has_prefill)
         maybe_npu_prefetch(inputs=self.o_proj.weight,
                            dependency=attn_output,
                            max_size=MAX_O_PROJ_PREFETCH_SIZE,
@@ -299,7 +299,6 @@ class MtpProposer(Proposer):
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,
-                with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
@@ -779,7 +778,6 @@ class MtpProposer(Proposer):
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_input_tokens,
-                with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
@@ -1424,7 +1424,6 @@ class NPUModelRunner(GPUModelRunner):
                 self.vllm_config,
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
-                with_prefill=self.with_prefill,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
@@ -2137,7 +2136,6 @@ class NPUModelRunner(GPUModelRunner):
                 self.vllm_config,
                 num_tokens=num_tokens_padded,
                 num_tokens_across_dp=num_tokens_across_dp,
-                with_prefill=with_prefill,
                 in_profile_run=is_profile,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,