Spec decode support for V1 Engine (#874)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->
Make spec decode support for V1 Engine
- Currently, Ascend does not support the triton kernel. PyTorch is used
to rewrite the `rejection_sampler.py` triton kernel. However, PyTorch is
not as good as Triton. Therefore, ascend c is used to implement the
function in the future.
- Currently, spec decode supports only the ngram algorithm. The eagle
algorithm needs to be further adapted.
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
Not change user facing.

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
test by `tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py` and
`tests/sample/test_rejection_sampler.py`, test base function of
rejection sampler and e2e function of spec decode.

Signed-off-by: ponix-j <657511300@qq.com>
This commit is contained in:
jiangpeng
2025-05-23 14:25:46 +08:00
committed by GitHub
parent a970b27e2d
commit df58fb80ee
12 changed files with 1553 additions and 12 deletions

View File

@@ -47,7 +47,12 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -55,6 +60,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.utils import vllm_version_is
if TYPE_CHECKING:
@@ -110,6 +116,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.model_config = vllm_config.model_config
self.lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
self.device = device
@@ -202,6 +209,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
# Set up speculative decoding.
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
if get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.method == "eagle":
self.drafter = EagleProposer(self.vllm_config,
self.device) # type: ignore
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
self.rejection_sampler = AscendRejectionSampler()
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
@@ -511,7 +533,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
torch.Tensor, int, torch.Tensor]:
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -523,6 +546,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_input_tokens = self.vllm_config.pad_for_cudagraph(
total_num_scheduled_tokens)
else:
# Eager mode.
num_input_tokens = total_num_scheduled_tokens
modified_batch = self.attn_metadata_builder.reorder_batch(
@@ -615,6 +639,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
common_prefix_len=None,
**extra_builder_kwargs,
)
attn_metadata.num_input_tokens = num_input_tokens
# Prepare input_ids
token_indices = (positions_np +
@@ -670,7 +695,106 @@ class NPUModelRunner(LoRAModelRunnerMixin):
**model_kwargs,
)
return hidden_states[sample_indices]
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
sample_indices = spec_decode_metadata.logits_indices
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
total_num_scheduled_tokens, sample_indices)
def _calc_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
cu_num_scheduled_tokens: np.ndarray,
) -> SpecDecodeMetadata:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
# num_draft_tokens: [ 3, 0, 2, 0, 1]
# Outputs:
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
# 206, 207, 208]
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
# bonus_logits_indices: [ 3, 4, 7, 8, 10]
# Compute the logits indices.
# [4, 1, 3, 1, 2]
num_sampled_tokens = num_draft_tokens + 1
# Step 1. [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
total_num_sampled_tokens = cu_num_sampled_tokens[-1]
# Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
# Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_indices = np.repeat(
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
logits_indices += arange
# Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1
# Compute the draft logits indices.
# [3, 3, 5, 5, 6]
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
total_num_draft_tokens = cu_num_draft_tokens[-1]
# [0, 0, 0, 3, 3, 5]
cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
num_draft_tokens)
# [0, 1, 2, 0, 1, 0]
arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
# [0, 0, 0, 5, 5, 9]
target_logits_indices = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange
# TODO: Optimize the CPU -> NPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True)
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
metadata = SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=num_draft_tokens.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
return metadata
def apply_grammar_bitmask(
self,
@@ -726,6 +850,30 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
return logits.to(self.device).to(logits_dtype)
def _get_spec_token_ids(
self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
scheduler_output: "SchedulerOutput",
spec_decode_metadata: SpecDecodeMetadata,
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: SpecDecodeMetadata,
) -> Optional[list[list[int]]]:
if not self.use_spec_decode:
# Speculative decoding is not enabled.
spec_token_ids = None
elif self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self._generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.method == "eagle":
raise NotImplementedError(
"eagle method for spec decode doesn't work on vllm-ascend currently"
)
return spec_token_ids
@torch.inference_mode()
def execute_model(
self,
@@ -736,9 +884,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOuptut if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
hidden_states = self._process_reqs(scheduler_output,
intermediate_tensors)
logits = self.model.compute_logits(hidden_states, None)
(attn_metadata, hidden_states, spec_decode_metadata, positions,
num_scheduled_tokens,
sample_indices) = (self._process_reqs(scheduler_output,
intermediate_tensors))
logits = self.model.compute_logits(hidden_states[sample_indices], None)
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
@@ -746,10 +896,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
if spec_decode_metadata is None:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
@@ -776,12 +951,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)
spec_token_ids = self._get_spec_token_ids(
valid_sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
num_scheduled_tokens,
hidden_states,
attn_metadata,
)
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=None,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict={},
)
@@ -968,6 +1160,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
if hasattr(self, "drafter"):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
@@ -1132,3 +1327,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))
def _generate_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
draft_token_ids.append([])
continue
# Skip requests that require top-p, top-k, etc.
req_id = self.input_batch.req_ids[i]
if not is_spec_decode_supported(req_id, self.input_batch):
draft_token_ids.append([])
continue
# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids