[main][Refactor] Remove with_prefill parameter from set_ascend_forward_context (#5094)
Removes the redundant `with_prefill` parameter from
`set_ascend_forward_context`, aligning the interface with vLLM's
`set_forward_context` in preparation for future refactoring. Attention
backends now derive the prefill state locally from attention metadata
(`num_prefills > 0`) instead of reading it from the forward context.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Signed-off-by: Slightwind <slightwindsec@gmail.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
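
A minimal runnable sketch of the change follows, using simplified stand-in
types rather than the real vllm-ascend API: the forward context no longer
carries a prefill flag, and attention code derives prefill state from the
batch's attention metadata instead.

# A minimal sketch of the change, assuming simplified stand-in types;
# this is not the real vllm-ascend implementation.
from contextlib import contextmanager
from types import SimpleNamespace

@contextmanager
def set_ascend_forward_context(attn_metadata, num_tokens: int = 0):
    # After this commit the context no longer carries a `with_prefill` flag.
    yield SimpleNamespace(num_tokens=num_tokens)

# Attention code now derives prefill state from the metadata directly:
attn_metadata = SimpleNamespace(num_prefills=2, num_decodes=4)
with set_ascend_forward_context(attn_metadata, num_tokens=6) as ctx:
    has_prefill = attn_metadata.num_prefills > 0  # replaces ctx.with_prefill
    print(has_prefill)  # True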
@@ -31,7 +31,6 @@ def set_ascend_forward_context(
         virtual_engine: int = 0,
         num_tokens: int = 0,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
-        with_prefill: bool = True,
         in_profile_run: bool = False,
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
@@ -60,7 +59,6 @@ set_ascend_forward_context
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
 
-        forward_context.with_prefill = with_prefill
         tp_world_size = get_tensor_model_parallel_world_size()
 
         forward_context.in_profile_run = in_profile_run
@@ -594,6 +594,9 @@ class AscendMlaCPImpl(AscendMLAImpl):
                 self.vllm_config, self.o_proj):
             reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
 
+        forward_context = get_forward_context()
+
         if self.pcp_size > 1:
             num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
         else:
@@ -601,19 +604,19 @@ class AscendMlaCPImpl(AscendMLAImpl):
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
 
+        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        o_proj_input_shape = (get_forward_context().num_tokens,
+        o_proj_input_shape = (forward_context.num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
                                    device=hidden_states.device)
 
         # MLA Preprocess
-        forward_context = get_forward_context()
-        if (self.enable_mlapo and
-                (attn_metadata is None or not forward_context.with_prefill)):
+        if self.enable_mlapo and not has_prefill:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
@@ -671,7 +674,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
 
         del o_proj_input
 
-        has_prefill = attn_metadata.num_prefills > 0
         if has_prefill:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded
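
Note that the rewritten condition drops the old `attn_metadata is None`
guard; that is safe as long as the method returns early when metadata is
absent, which the `return output.fill_(0)` path above suggests. A sketch of
the predicate rewrite, with hypothetical helper names (the same rewrite is
applied to `AscendMLAImpl` below):

from types import SimpleNamespace

# Hypothetical helpers illustrating the predicate rewrite; assumes
# attn_metadata is non-None past the early return, as in the diff above.
def mlapo_decode_only_old(enable_mlapo, attn_metadata, forward_context):
    return enable_mlapo and (attn_metadata is None
                             or not forward_context.with_prefill)

def mlapo_decode_only_new(enable_mlapo, attn_metadata):
    has_prefill = attn_metadata.num_prefills > 0
    return enable_mlapo and not has_prefill

meta = SimpleNamespace(num_prefills=0)
ctx = SimpleNamespace(with_prefill=False)
assert mlapo_decode_only_old(True, meta, ctx) == mlapo_decode_only_new(True, meta)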
@@ -1395,23 +1395,25 @@ class AscendMLAImpl(MLAAttentionImpl):
                 self.vllm_config, self.o_proj):
             reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
 
+        forward_context = get_forward_context()
         num_actual_tokens = attn_metadata.num_actual_tokens
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
             attn_metadata.num_decode_tokens is not None
 
+        has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        o_proj_input_shape = (get_forward_context().num_tokens,
+        o_proj_input_shape = (forward_context.num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
                                    device=hidden_states.device)
 
         # MLA Preprocess
-        forward_context = get_forward_context()
-        if (self.enable_mlapo and
-                (attn_metadata is None or not forward_context.with_prefill)):
+        if self.enable_mlapo and not has_prefill:
             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 hidden_states.contiguous(), need_gather_q_kv)
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
@@ -1455,7 +1457,6 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         del o_proj_input
 
-        has_prefill = attn_metadata.num_prefills > 0
         if has_prefill:
             maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded
@@ -476,12 +476,14 @@ class AscendSFAImpl(MLAAttentionImpl):
             # if mlapo, W_UK_T can't trans nz
             self.W_UK_T = maybe_trans_nz(self.W_UK_T)
 
-    def _v_up_proj(self, x):
-        forward_context = get_forward_context()
+    def _v_up_proj(self, x, has_prefill: bool):
+        # TODO(zzzzwwjj): We should not judge by `has_prefill` here. The true
+        # criterion is tensorA's shape[0] <= 1024 (num_tokens <= 1024).
+        # This is a bug in the previous code.
         if x.dtype in [torch.float16, torch.bfloat16] \
                 and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
                 and not self.enable_sfa_cp \
-                and not forward_context.with_prefill:
+                and not has_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),
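
The TODO above indicates the real gate should be the input's token count
(shape[0] <= 1024) rather than prefill state. A hedged sketch of that
criterion; the helper name and threshold constant are assumptions taken
from the comment, not vllm-ascend code:

import torch

# Hypothetical sketch of the criterion the TODO describes: gate the fused
# batch_matmul_transpose path on num_tokens, not on prefill state.
MAX_BMM_TRANSPOSE_TOKENS = 1024  # assumed threshold, per the TODO

def can_use_bmm_transpose(x: torch.Tensor) -> bool:
    # x is expected as (num_tokens, num_heads * kv_lora_rank)
    return (x.dtype in (torch.float16, torch.bfloat16)
            and x.shape[0] <= MAX_BMM_TRANSPOSE_TOKENS)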
@@ -766,7 +768,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
 
-        if self.enable_mlapo and not forward_context.with_prefill:
+        # TODO(zzzzwwjj): In SFA, prefill and decode share the same calculation
+        # formula, so `has_prefill` is not strictly necessary here.
+        if self.enable_mlapo and not has_prefill:
             hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode(
                 hidden_states=hidden_states,
                 kv_cache=kv_cache,
@@ -841,7 +845,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             layout_kv="PA_BSND",
             sparse_mode=3,
         )
-        attn_output = self._v_up_proj(attn_output)
+        attn_output = self._v_up_proj(attn_output, has_prefill)
         maybe_npu_prefetch(inputs=self.o_proj.weight,
                            dependency=attn_output,
                            max_size=MAX_O_PROJ_PREFETCH_SIZE,
@@ -299,7 +299,6 @@ class MtpProposer(Proposer):
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,
-                with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
@@ -779,7 +778,6 @@ class MtpProposer(Proposer):
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_input_tokens,
-                with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
@@ -1424,7 +1424,6 @@ class NPUModelRunner(GPUModelRunner):
                 self.vllm_config,
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
-                with_prefill=self.with_prefill,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
@@ -2137,7 +2136,6 @@ class NPUModelRunner(GPUModelRunner):
                 self.vllm_config,
                 num_tokens=num_tokens_padded,
                 num_tokens_across_dp=num_tokens_across_dp,
-                with_prefill=with_prefill,
                 in_profile_run=is_profile,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
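
At every call site the migration is simply deleting the `with_prefill=`
keyword. A runnable sketch with a simplified stand-in signature (not the
real API) shows that a stale call now fails fast:

# Simplified stand-in for set_ascend_forward_context: after the removal,
# passing with_prefill raises TypeError at the call sites shown above.
def set_ascend_forward_context(attn_metadata, vllm_config, *,
                               num_tokens=0, num_tokens_across_dp=None,
                               in_profile_run=False, num_actual_tokens=None):
    return {"num_tokens": num_tokens}

try:
    set_ascend_forward_context(None, None, num_tokens=8, with_prefill=True)
except TypeError as err:
    print(err)  # ...got an unexpected keyword argument 'with_prefill'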