[Fix] Refactor dummy attention metadata creation (#3497)
### What this PR does / why we need it? The `force_attention` parameter is designed for flash infer kernel warmup, we don't actually need it on Ascend device (at least for now).And it tends to make things more complicated. So we replace the `force_attention` parameter with `aclgraph_runtime_mode` in the attention metadata creation logic. This change makes the control flow more explicit by directly using the graph runtime mode to determine how to build attention metadata, rather than relying on an intermediate boolean flag. This simplification removes redundant logic and clarifies the conditions for building attention metadata for full decode graph mode. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? DP + `FULL_DECODE_ONLY` + online serving. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -2250,18 +2250,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output.finished_req_ids)
|
||||
return None, None
|
||||
|
||||
def _build_attention_metadata(self, create_mixed_batch, num_reqs,
|
||||
num_tokens, max_query_len, force_attention):
|
||||
def _build_dummy_attn_metadata(
|
||||
self,
|
||||
with_prefill: bool,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
max_query_len: int,
|
||||
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||
force_attention: bool = False,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
attn_metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
if force_attention:
|
||||
if force_attention or aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
assert with_prefill is False, \
|
||||
"Full decode graph only supports uniform batch now."
|
||||
|
||||
attn_metadata = {}
|
||||
|
||||
if create_mixed_batch:
|
||||
raise NotImplementedError(
|
||||
"force_attention=True is not supported for mixed batches.")
|
||||
else:
|
||||
seq_lens = self.model_config.max_model_len
|
||||
seq_lens = self.model_config.max_model_len
|
||||
self.seq_lens_np[:num_reqs] = seq_lens
|
||||
self.seq_lens_np[num_reqs:] = 0
|
||||
|
||||
@@ -2321,7 +2327,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
forward_context = get_forward_context()
|
||||
assert forward_context is not None
|
||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||
not forward_context.capturing:
|
||||
not forward_context.capturing and forward_context.attn_metadata is not None:
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
# FIXME: Try using `auto_dispatch_capture=True`
|
||||
update_mla_attn_params(self.update_stream, forward_context,
|
||||
@@ -2409,12 +2415,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.is_kv_producer and not self.is_kv_consumer:
|
||||
with_prefill = True
|
||||
|
||||
# TODO(cmq): check if with_prefill is reasonable
|
||||
attn_metadata = self._build_attention_metadata(
|
||||
# TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
|
||||
# and not supported in ASCEND now. We could remove it in the future.
|
||||
attn_metadata = self._build_dummy_attn_metadata(
|
||||
False,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
max_query_len=max_query_len,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
force_attention=force_attention,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user