[310P][Bugfix]: fix ngram graph replay accuracy error (#7134)
### What this PR does / why we need it?
On the 310P device, when running ACLGraph together with the n-gram
speculative decoding algorithm, both graph capture and graph replay
require `uniform_decode_query_len` and do not depend on
`attention_state`. This causes an unexpected behavior inversion on 310P:
during decode-only, execution does **not** enter graph replay, while in
the split-fuse state (that is, the chunked prefill state), execution
enters graph replay directly.
The issue can be resolved by forcibly setting `uniform_decode_query_len`
to `1`, so that 310P captures only the decode-only graph, and replay is
then controlled through `attention_state`.
### Does this PR introduce _any_ user-facing change?
NO
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
@@ -17,6 +17,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -24,14 +26,151 @@ from vllm.logger import logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
# On 310P, ngram decode-only graphs are always captured with query_len=1;
# graph replay selection is then driven by attention_state instead.
_NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN = 1
|
||||
|
||||
|
||||
class NPUModelRunner310(NPUModelRunner):
|
||||
# Inherited from parent runner; annotated here to satisfy strict type checks.
|
||||
uniform_decode_query_len: int
|
||||
|
||||
def __init__(self, *args, **kwargs):
    """Initialize the 310P runner and align graph shapes for ngram decoding.

    Sets the NPU tensor format to FRACTAL_NZ and, when ngram speculative
    decoding is configured, forces the dispatcher's uniform decode query
    length to 1 so only the decode-only graph is captured on 310P.
    """
    super().__init__(*args, **kwargs)
    self._acl_format = ACL_FORMAT_FRACTAL_NZ

    spec_cfg = self.speculative_config
    if spec_cfg is None or spec_cfg.method != "ngram":
        return

    # 310P ngram requires decode-only graph shapes to be built with
    # q_len=1; keep the dispatcher's internal query_len in sync to
    # avoid the key-init assert.
    self.cudagraph_dispatcher.uniform_decode_query_len = (
        _NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN
    )
|
||||
|
||||
@contextmanager
def temporary_modify_uniform_decode_query_len(self):
    """Temporarily pin the runner's uniform_decode_query_len to 1.

    Only needed for the 310P ngram path, where the dispatcher uses
    q_len=1 while the runner's default uniform_decode_query_len remains
    1 + num_spec_tokens. A no-op for every other configuration.

    TODO: remove this temporary override once upstream supports an
    independent decode-capture query_len for backend-specific paths.
    """
    spec_cfg = self.speculative_config
    if spec_cfg is None or spec_cfg.method != "ngram":
        # Nothing to align outside the ngram path.
        yield
        return

    saved_query_len = self.uniform_decode_query_len
    self.uniform_decode_query_len = _NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN
    try:
        yield
    finally:
        # Restore unconditionally, even if the wrapped code raises.
        self.uniform_decode_query_len = saved_query_len
|
||||
|
||||
def _determine_batch_execution_and_padding(
    self,
    num_tokens: int,
    num_reqs: int,
    num_scheduled_tokens_np: np.ndarray,
    max_num_scheduled_tokens: int,
    use_cascade_attn: bool,
    allow_microbatching: bool = False,
    force_eager: bool = False,
    force_uniform_decode: bool | None = None,
    force_has_lora: bool | None = None,
    force_num_active_loras: int | None = None,
    num_encoder_reqs: int = 0,
):
    """Decide execution mode and padding with 310P-specific overrides.

    Split-fuse (chunked prefill) and prefill-cache-hit batches are
    forced to eager execution so they never replay the captured decode
    graph; decode-only batches that are uniform at q_len=1 are flagged
    as uniform decode. Everything else is delegated to the parent.
    """
    eager_states = (
        AscendAttentionState.ChunkedPrefill,
        AscendAttentionState.PrefillCacheHit,
    )
    if self.attn_state in eager_states:
        force_eager = True

    # Respect an explicit caller override: only force when unset.
    if (force_uniform_decode is None
            and self.attn_state == AscendAttentionState.DecodeOnly):
        q_len = _NGRAM_GRAPH_UNIFORM_DECODE_QUERY_LEN
        is_uniform_q1_decode = (
            max_num_scheduled_tokens == q_len
            and num_tokens == max_num_scheduled_tokens * num_reqs
            and np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] > 0)
        )
        if is_uniform_q1_decode:
            force_uniform_decode = True

    return super()._determine_batch_execution_and_padding(
        num_tokens=num_tokens,
        num_reqs=num_reqs,
        num_scheduled_tokens_np=num_scheduled_tokens_np,
        max_num_scheduled_tokens=max_num_scheduled_tokens,
        use_cascade_attn=use_cascade_attn,
        allow_microbatching=allow_microbatching,
        force_eager=force_eager,
        force_uniform_decode=force_uniform_decode,
        force_has_lora=force_has_lora,
        force_num_active_loras=force_num_active_loras,
        num_encoder_reqs=num_encoder_reqs,
    )
|
||||
|
||||
def _pad_query_start_loc_for_fia(self, num_tokens_padded: int, num_reqs_padded: int, num_reqs: int) -> int:
    """Pad query_start_loc to the graph-captured batch shape and return the
    effective padded request count.

    Keep this aligned with the dispatcher because batch_desc.num_reqs is
    generated by dispatcher._create_padded_batch_descriptor().
    For 310P ngram we intentionally set dispatcher q_len=1, while runner's
    default uniform_decode_query_len may remain 1 + num_spec_tokens.

    NOTE(review): assumes self.query_start_loc is a paired CPU/NPU buffer
    (``.np`` view plus ``copy_to_gpu()``) with at least num_reqs_padded + 2
    entries — confirm against the parent runner's buffer allocation.
    """
    uniform_decode_query_len = self.cudagraph_dispatcher.uniform_decode_query_len

    if num_tokens_padded == num_reqs_padded * uniform_decode_query_len:
        # Uniform-batch case: num_reqs must be no greater than num_reqs_padded
        assert num_reqs <= num_reqs_padded

        # Extend the cumulative offsets linearly: each padded (dummy)
        # request contributes exactly uniform_decode_query_len tokens.
        last_loc = self.query_start_loc.np[num_reqs]
        self.query_start_loc.np[num_reqs + 1 : num_reqs_padded + 1] = (
            self.arange_np[1 : num_reqs_padded + 1 - num_reqs] * uniform_decode_query_len + last_loc
        )
    else:
        # Mixed-batch case: num_reqs must equal num_reqs_padded
        assert num_reqs == num_reqs_padded

        # Insert a dummy request instead of setting query_start_loc[num_reqs] = num_tokens_padded directly
        self.query_start_loc.np[num_reqs_padded + 1] = num_tokens_padded
        num_reqs_padded = num_reqs_padded + 1

    # Publish the padded offsets to the device before graph execution.
    self.query_start_loc.copy_to_gpu()
    return num_reqs_padded
|
||||
|
||||
@torch.inference_mode()
def _dummy_run(
    self,
    num_tokens: int,
    with_prefill: bool = False,
    cudagraph_runtime_mode=None,
    force_attention: bool = False,
    uniform_decode: bool = False,
    is_profile: bool = False,
    create_mixed_batch: bool = False,
    allow_microbatching: bool = True,
    skip_eplb: bool = False,
    remove_lora: bool = True,
    is_graph_capturing: bool = False,
    num_active_loras: int = 0,
):
    """Run the parent dummy pass for warmup/profiling/graph capture.

    For uniform-decode runs the runner's query_len is temporarily
    aligned with the dispatcher (q_len=1 on the 310P ngram path) so the
    captured graph shape matches what replay will dispatch.
    """
    if uniform_decode:
        align_ctx = self.temporary_modify_uniform_decode_query_len()
    else:
        align_ctx = nullcontext()

    with align_ctx:
        return super()._dummy_run(
            num_tokens=num_tokens,
            with_prefill=with_prefill,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            force_attention=force_attention,
            uniform_decode=uniform_decode,
            is_profile=is_profile,
            create_mixed_batch=create_mixed_batch,
            allow_microbatching=allow_microbatching,
            skip_eplb=skip_eplb,
            remove_lora=remove_lora,
            is_graph_capturing=is_graph_capturing,
            num_active_loras=num_active_loras,
        )
|
||||
|
||||
def _check_and_update_cudagraph_mode(
    self,
    attention_backends,
    kv_cache_groups,
) -> None:
    """Validate/update the graph mode with dispatcher-aligned query_len.

    910B does not need this alignment because runner and dispatcher
    query_len are naturally consistent there; the 310P ngram path needs
    the temporary override while the parent check runs.
    """
    align_ctx = self.temporary_modify_uniform_decode_query_len()
    with align_ctx:
        super()._check_and_update_cudagraph_mode(
            attention_backends, kv_cache_groups)
|
||||
|
||||
def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user