Spec decode support for V1 Engine (#874)
### What this PR does / why we need it?
Add speculative decoding (spec decode) support to the V1 Engine.
- Ascend does not currently support Triton kernels, so the Triton kernel in `rejection_sampler.py` is rewritten with plain PyTorch (a minimal sketch of the sampling rule appears below). Since PyTorch does not match Triton's performance here, the function will be reimplemented in Ascend C in the future.
- Spec decode currently supports only the ngram algorithm; the eagle algorithm still needs further adaptation.

### Does this PR introduce _any_ user-facing change?
No user-facing change.

### How was this patch tested?
Tested with `tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py` and `tests/sample/test_rejection_sampler.py`, which cover the basic behavior of the rejection sampler and the end-to-end spec decode flow.

Signed-off-by: ponix-j <657511300@qq.com>
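For reference, the acceptance rule that the Triton kernel in vLLM's `rejection_sampler.py` implements can be expressed with plain PyTorch ops. The sketch below is a minimal single-request illustration of the no-draft-probability path used with ngram drafting (where `draft_probs` is `None`); the function name and the simplified recovery step are illustrative assumptions, not the actual `AscendRejectionSampler` code.

```python
# Minimal sketch of spec-decode rejection sampling in plain PyTorch, for one
# request whose drafter provides no probabilities (the ngram case).
# Illustrative only; NOT the AscendRejectionSampler implementation.
import torch


def rejection_sample_no_draft_probs(
        target_logits: torch.Tensor,    # [num_draft_tokens, vocab_size]
        draft_token_ids: torch.Tensor,  # [num_draft_tokens]
        bonus_token_id: int,
) -> list[int]:
    target_probs = torch.softmax(target_logits.float(), dim=-1)
    output: list[int] = []
    for pos, draft_id in enumerate(draft_token_ids.tolist()):
        # With no draft distribution, accept the draft token with probability
        # equal to the target model's probability of that token.
        if torch.rand(()) < target_probs[pos, draft_id]:
            output.append(draft_id)
            continue
        # First rejection: sample a replacement from the target distribution
        # (simplified; the real kernel samples from an adjusted distribution)
        # and drop all remaining draft tokens.
        output.append(int(torch.multinomial(target_probs[pos], num_samples=1)))
        return output
    # Every draft token was accepted, so the bonus token sampled at the last
    # position is emitted as well.
    output.append(bonus_token_id)
    return output
```

In the real sampler this loop is batched across all requests and positions; that batched form is what was rewritten from Triton into PyTorch and is planned to move to Ascend C.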
@@ -47,7 +47,12 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -55,6 +60,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.utils import vllm_version_is

if TYPE_CHECKING:
@@ -110,6 +116,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.model_config = vllm_config.model_config
self.lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
self.device = device

@@ -202,6 +209,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

# Set up speculative decoding.
self.use_spec_decode = False
if self.speculative_config:
self.use_spec_decode = True
if get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.method == "eagle":
self.drafter = EagleProposer(self.vllm_config,
self.device) # type: ignore
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
self.rejection_sampler = AscendRejectionSampler()

# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
@@ -511,7 +533,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
torch.Tensor, int, torch.Tensor]:
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -523,6 +546,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_input_tokens = self.vllm_config.pad_for_cudagraph(
total_num_scheduled_tokens)
else:
# Eager mode.
num_input_tokens = total_num_scheduled_tokens

modified_batch = self.attn_metadata_builder.reorder_batch(
@@ -615,6 +639,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
common_prefix_len=None,
**extra_builder_kwargs,
)
attn_metadata.num_input_tokens = num_input_tokens

# Prepare input_ids
token_indices = (positions_np +
@@ -670,7 +695,106 @@ class NPUModelRunner(LoRAModelRunnerMixin):
**model_kwargs,
)

return hidden_states[sample_indices]
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)

spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
sample_indices = spec_decode_metadata.logits_indices

return (attn_metadata, hidden_states, spec_decode_metadata, positions,
total_num_scheduled_tokens, sample_indices)

def _calc_spec_decode_metadata(
self,
num_draft_tokens: np.ndarray,
cu_num_scheduled_tokens: np.ndarray,
) -> SpecDecodeMetadata:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
# num_draft_tokens: [ 3, 0, 2, 0, 1]
# Outputs:
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
# 206, 207, 208]
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
# bonus_logits_indices: [ 3, 4, 7, 8, 10]

# Compute the logits indices.
# [4, 1, 3, 1, 2]
num_sampled_tokens = num_draft_tokens + 1
# Step 1. [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
total_num_sampled_tokens = cu_num_sampled_tokens[-1]
# Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
# Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_indices = np.repeat(
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
logits_indices += arange

# Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1

# Compute the draft logits indices.
# [3, 3, 5, 5, 6]
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
total_num_draft_tokens = cu_num_draft_tokens[-1]
# [0, 0, 0, 3, 3, 5]
cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
num_draft_tokens)
# [0, 1, 2, 0, 1, 0]
arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
# [0, 0, 0, 5, 5, 9]
target_logits_indices = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange

# TODO: Optimize the CPU -> NPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True)

# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]

metadata = SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=num_draft_tokens.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
return metadata
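The index arithmetic above can be checked standalone against the example values in the comments. The following is a small illustrative script (not part of the change); `arange_np` stands in for `self.arange_np`.

```python
# Reproduce the worked example from the comments in _calc_spec_decode_metadata.
import numpy as np

cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209], dtype=np.int32)
num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)
arange_np = np.arange(1024, dtype=np.int32)

num_sampled_tokens = num_draft_tokens + 1                    # [4 1 3 1 2]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
                    num_sampled_tokens)
arange = arange_np[:cu_num_sampled_tokens[-1]] - offsets
logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens,
                           num_sampled_tokens) + arange

print(logits_indices.tolist())
# -> [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
print((cu_num_sampled_tokens - 1).tolist())  # bonus_logits_indices
# -> [3, 4, 7, 8, 10]
```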
def apply_grammar_bitmask(
self,
@@ -726,6 +850,30 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
return logits.to(self.device).to(logits_dtype)

def _get_spec_token_ids(
self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
scheduler_output: "SchedulerOutput",
spec_decode_metadata: SpecDecodeMetadata,
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: SpecDecodeMetadata,
) -> Optional[list[list[int]]]:
if not self.use_spec_decode:
# Speculative decoding is not enabled.
spec_token_ids = None
elif self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self._generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.method == "eagle":
raise NotImplementedError(
"eagle method for spec decode doesn't work on vllm-ascend currently"
)
return spec_token_ids

@torch.inference_mode()
def execute_model(
self,
@@ -736,9 +884,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOuptut if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
hidden_states = self._process_reqs(scheduler_output,
intermediate_tensors)
logits = self.model.compute_logits(hidden_states, None)
(attn_metadata, hidden_states, spec_decode_metadata, positions,
num_scheduled_tokens,
sample_indices) = (self._process_reqs(scheduler_output,
intermediate_tensors))
logits = self.model.compute_logits(hidden_states[sample_indices], None)

# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
@@ -746,10 +896,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):

# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
if spec_decode_metadata is None:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids

# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids

# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
@@ -776,12 +951,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)

spec_token_ids = self._get_spec_token_ids(
valid_sampled_token_ids,
sampling_metadata,
scheduler_output,
spec_decode_metadata,
positions,
num_scheduled_tokens,
hidden_states,
attn_metadata,
)

model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=None,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict={},
)
@@ -968,6 +1160,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):

with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
if hasattr(self, "drafter"):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
@@ -1132,3 +1327,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))

def _generate_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
draft_token_ids.append([])
continue

# Skip requests that require top-p, top-k, etc.
req_id = self.input_batch.req_ids[i]
if not is_spec_decode_supported(req_id, self.input_batch):
draft_token_ids.append([])
continue

# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids
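For context on the call to `self.drafter.propose(...)` above: the ngram drafter takes the request's token history (`token_ids_cpu[i, :end_idx]`) and proposes draft tokens by n-gram lookup. The sketch below is only a rough illustration of that idea; `n` and `k` are illustrative parameters, not the proposer's real configuration fields.

```python
# Rough sketch of n-gram draft proposal (illustrative only, not the real
# NgramProposer): find the most recent earlier occurrence of the last `n`
# tokens and replay up to `k` tokens that followed it as the draft.
import numpy as np


def ngram_propose(token_ids: np.ndarray, n: int = 2, k: int = 4) -> np.ndarray:
    if token_ids.shape[0] <= n:
        return np.empty(0, dtype=token_ids.dtype)
    pattern = token_ids[-n:]
    # Scan backwards; the range excludes the trailing pattern occurrence itself.
    for start in range(token_ids.shape[0] - n - 1, -1, -1):
        if np.array_equal(token_ids[start:start + n], pattern):
            return token_ids[start + n:start + n + k]
    return np.empty(0, dtype=token_ids.dtype)
```

When no match exists, the sketch returns an empty array, which corresponds to the `drafter_output is None or len(drafter_output) == 0` branch above.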