Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

1331
vllm/v1/spec_decode/eagle.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.v1.sample.metadata import SamplingMetadata
# Initialize logger
logger = init_logger(__name__)
class MedusaProposer:
"""
Medusa proposer class for generating token sequences
"""
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
# Save config parameters
self.vllm_config = vllm_config
self.device = device
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.hidden_size = (
vllm_config.speculative_config.draft_model_config.get_hidden_size()
)
self.dtype = vllm_config.model_config.dtype
def propose(
self,
target_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# Generate blocks and compute logits
blocks = self.model(target_hidden_states)
logits = self.model.compute_logits(blocks)
# Compute argmax for each Medusa head and stack into a single tensor
# Shape: [batch_size, num_heads]
draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
return draft_tokens
def load_model(self, target_model: nn.Module) -> None:
from vllm.compilation.backends import set_model_tag
with set_model_tag("medusa_head"):
self.model = get_model(
vllm_config=self.vllm_config,
model_config=self.vllm_config.speculative_config.draft_model_config,
)
assert not (
is_mixture_of_experts(self.model)
and self.vllm_config.parallel_config.enable_eplb
), "EPLB for Medusa is not supported"
@torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None:
hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device,
)
with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
self.model(hidden_states)

View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import numpy as np
import torch
@dataclass
class SpecDecodeMetadata:
# [num_tokens]
draft_token_ids: torch.Tensor
# [batch_size]
num_draft_tokens: list[int]
# [batch_size]
cu_num_draft_tokens: torch.Tensor
# [batch_size]
cu_num_sampled_tokens: torch.Tensor
# [num_tokens]
target_logits_indices: torch.Tensor
# [batch_size]
bonus_logits_indices: torch.Tensor
# [num_tokens + batch_size]
logits_indices: torch.Tensor
def __post_init__(self):
self.max_spec_len = max(self.num_draft_tokens)
@classmethod
def make_dummy(
cls,
draft_token_ids: list[list[int]],
device: torch.device,
) -> "SpecDecodeMetadata":
batch_size = len(draft_token_ids)
num_draft_tokens = [len(ids) for ids in draft_token_ids]
num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]
flattened_draft_token_ids = sum(draft_token_ids, [])
num_tokens = len(flattened_draft_token_ids)
draft_token_ids_tensor = torch.tensor(
flattened_draft_token_ids, dtype=torch.int32, device=device
)
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
device
)
target_logits_indices = torch.zeros(
num_tokens, dtype=torch.int32, device=device
)
bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device)
logits_indices = torch.zeros(
num_tokens + batch_size, dtype=torch.int32, device=device
)
return cls(
draft_token_ids=draft_token_ids_tensor,
num_draft_tokens=num_draft_tokens,
cu_num_draft_tokens=cu_num_draft_tokens_tensor,
cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)

View File

@@ -0,0 +1,225 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from dataclasses import dataclass, field
import numpy as np
import prometheus_client
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class SpecDecodingStats:
"""Per-step iteration decoding stats from scheduler.
Each scheduler step, statistics on spec decoding performance are
aggregated across requests by the scheduler and returned to the
frontend in EngineCoreOutputs->SchedulerStats.
"""
num_spec_tokens: int
num_drafts: int = 0
num_draft_tokens: int = 0
num_accepted_tokens: int = 0
num_accepted_tokens_per_pos: list[int] = field(default_factory=list)
@classmethod
def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
return cls(
num_spec_tokens=num_spec_tokens,
num_accepted_tokens_per_pos=[0] * num_spec_tokens,
)
def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
self.num_drafts += 1
self.num_draft_tokens += num_draft_tokens
self.num_accepted_tokens += num_accepted_tokens
assert num_accepted_tokens <= self.num_spec_tokens
for i in range(num_accepted_tokens):
self.num_accepted_tokens_per_pos[i] += 1
class SpecDecodingLogging:
"""Aggregate and log spec decoding metrics.
LoggingStatLogger aggregates per-iteration metrics over a set
time interval using observe() and then logs them using log()
before resetting to zero.
"""
def __init__(self):
self.reset()
def reset(self):
self.num_drafts: list[int] = []
self.num_draft_tokens: list[int] = []
self.num_accepted_tokens: list[int] = []
self.accepted_tokens_per_pos_lists: list[list[int]] = []
self.last_log_time = time.monotonic()
def observe(self, spec_decoding_stats: SpecDecodingStats):
self.num_drafts.append(spec_decoding_stats.num_drafts)
self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens)
self.num_accepted_tokens.append(spec_decoding_stats.num_accepted_tokens)
self.accepted_tokens_per_pos_lists.append(
spec_decoding_stats.num_accepted_tokens_per_pos
)
def log(self, log_fn=logger.info):
if not self.num_drafts:
return
num_drafts = np.sum(self.num_drafts)
num_draft_tokens = np.sum(self.num_draft_tokens)
num_accepted_tokens = np.sum(self.num_accepted_tokens)
draft_throughput = 0
accepted_throughput = 0
elapsed_time = time.monotonic() - self.last_log_time
if elapsed_time > 0:
draft_throughput = num_draft_tokens / elapsed_time
accepted_throughput = num_accepted_tokens / elapsed_time
draft_acceptance_rate = (
num_accepted_tokens / num_draft_tokens * 100
if num_draft_tokens > 0
else float("nan")
)
# Conventionally, mean acceptance length includes the bonus token
mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)
pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)
log_fn(
"SpecDecoding metrics: "
"Mean acceptance length: %.2f, "
"Accepted throughput: %.2f tokens/s, "
"Drafted throughput: %.2f tokens/s, "
"Accepted: %d tokens, "
"Drafted: %d tokens, "
"Per-position acceptance rate: %s, "
"Avg Draft acceptance rate: %.1f%%",
mean_acceptance_length,
accepted_throughput,
draft_throughput,
num_accepted_tokens,
num_draft_tokens,
rates_str,
draft_acceptance_rate,
)
self.reset()
class SpecDecodingProm:
"""Record spec decoding metrics in Prometheus.
The acceptance rate can be calculated using a PromQL query:
rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
rate(vllm:spec_decode_num_draft_tokens_total[$interval])
The mean acceptance length (conventionally including bonus tokens)
can be calculated using:
1 + (
rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
rate(vllm:spec_decode_num_drafts[$interval]))
A per-position acceptance rate vector can be computed using
vllm:spec_decode_num_accepted_tokens_per_pos[$interval] /
vllm:spec_decode_num_drafts[$interval]
"""
_counter_cls = prometheus_client.Counter
def __init__(
self,
speculative_config: SpeculativeConfig | None,
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
):
self.spec_decoding_enabled = speculative_config is not None
if not self.spec_decoding_enabled:
return
counter_drafts = self._counter_cls(
name="vllm:spec_decode_num_drafts",
documentation="Number of spec decoding drafts.",
labelnames=labelnames,
)
self.counter_spec_decode_num_drafts = make_per_engine(
counter_drafts, per_engine_labelvalues
)
counter_draft_tokens = self._counter_cls(
name="vllm:spec_decode_num_draft_tokens",
documentation="Number of draft tokens.",
labelnames=labelnames,
)
self.counter_spec_decode_num_draft_tokens = make_per_engine(
counter_draft_tokens, per_engine_labelvalues
)
counter_accepted_tokens = self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens",
documentation="Number of accepted tokens.",
labelnames=labelnames,
)
self.counter_spec_decode_num_accepted_tokens = make_per_engine(
counter_accepted_tokens, per_engine_labelvalues
)
assert speculative_config is not None
num_spec_tokens = (
speculative_config.num_speculative_tokens
if self.spec_decoding_enabled
else 0
)
pos_labelnames = labelnames + ["position"]
base_counter = self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens_per_pos",
documentation="Accepted tokens per draft position.",
labelnames=pos_labelnames,
)
self.counter_spec_decode_num_accepted_tokens_per_pos: dict[
int, list[prometheus_client.Counter]
] = {
idx: [base_counter.labels(*lv, str(pos)) for pos in range(num_spec_tokens)]
for idx, lv in per_engine_labelvalues.items()
}
def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0):
if not self.spec_decoding_enabled:
return
self.counter_spec_decode_num_drafts[engine_idx].inc(
spec_decoding_stats.num_drafts
)
self.counter_spec_decode_num_draft_tokens[engine_idx].inc(
spec_decoding_stats.num_draft_tokens
)
self.counter_spec_decode_num_accepted_tokens[engine_idx].inc(
spec_decoding_stats.num_accepted_tokens
)
for pos, counter in enumerate(
self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]
):
counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
def make_per_engine(
counter: prometheus_client.Counter,
per_engine_labelvalues: dict[int, list[object]],
):
"""Create a counter for each label value."""
return {
idx: counter.labels(*labelvalues)
for idx, labelvalues in per_engine_labelvalues.items()
}

View File

@@ -0,0 +1,291 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import numpy as np
from numba import get_num_threads, jit, njit, prange, set_num_threads
from vllm.config import VllmConfig
class NgramProposer:
def __init__(self, vllm_config: VllmConfig):
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.prompt_lookup_min is not None
assert vllm_config.speculative_config.prompt_lookup_max is not None
# Minimum length of the n-gram to match.
self.min_n = vllm_config.speculative_config.prompt_lookup_min
# Maximum length of the n-gram to match.
self.max_n = vllm_config.speculative_config.prompt_lookup_max
# Number of tokens follow the match. If there are less than k
# tokens follow the match, we will return the maximum amount of
# tokens until the end.
self.k = vllm_config.speculative_config.num_speculative_tokens
# Maximum length of the model.
self.max_model_len = vllm_config.model_config.max_model_len
# Pre-allocate buffers for numba batch propose.
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32)
self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)
# Threshold of total number of tokens in the batch to enable
# multi-threading in numba batch propose.
self.num_tokens_threshold = 8192
tp_size = vllm_config.parallel_config.tensor_parallel_size
cpu_count = os.cpu_count()
# Max number of threads for numba parallel processing.
if cpu_count:
# Divide by 2 to use physical cores
# and not logical cores (hyper-threading).
# Cap the number of threads to 8 to avoid using too many threads
# since other components like frontend (incl tokenization)
# and Structured Outputs also use multiple threads.
# TODO(ekagra-ranjan): bump up the cap from 1 to 8
# when TP parallelization for ngram is implemented.
self.num_numba_thread_available = min(1, (cpu_count // 2))
# Divide by tp_size to ensure each tensor parallel rank
# has some threads since all ranks will run this.
self.num_numba_thread_available //= tp_size
else:
self.num_numba_thread_available = 1
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.propose(
[[]] * 1024,
[""] * 1024,
np.zeros(1024, dtype=np.int32),
np.zeros((1024, self.max_model_len), dtype=np.int32),
set(),
)
def batch_propose(
self,
num_requests: int,
valid_ngram_requests: list,
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
) -> list[list[int]]:
"""Batch version of ngram proposer using numba for acceleration.
Args:
valid_ngram_requests:
Set of indices of requests that need ngram proposals.
num_tokens_no_spec:
Numpy array of shape (batch_size,) representing the number
of tokens without speculative tokens for each request.
token_ids_cpu:
Numpy array of shape (batch_size, max_model_len)
representing the token IDs for each request.
Returns:
list[list[int]]:
A list where each element is a list of proposed
token IDs for the corresponding request.
"""
draft_token_ids: list[list[int]] = []
# Only run batch propose if there are requests needing ngram proposals.
# avoid calling numba function with empty list which causes error
# ValueError: cannot compute fingerprint of empty list
if num_ngram_requests := len(valid_ngram_requests):
original_num_numba_threads = get_num_threads()
# Ensure we use at least one thread.
# If total tokens is small, using multiple threads
# may slow down due to overhead.
total_tokens = np.sum(num_tokens_no_spec)
if total_tokens >= self.num_tokens_threshold:
final_num_threads = max(
1, min(self.num_numba_thread_available, num_ngram_requests)
)
set_num_threads(final_num_threads)
else:
set_num_threads(1)
batch_propose_numba(
valid_ngram_requests,
num_tokens_no_spec,
token_ids_cpu,
self.min_n,
self.max_n,
self.max_model_len,
self.k,
self.valid_ngram_draft,
self.valid_ngram_num_drafts,
)
# Restore original number of threads.
set_num_threads(original_num_numba_threads)
for i in range(num_requests):
if i in valid_ngram_requests and self.valid_ngram_num_drafts[i] > 0:
draft_token_ids.append(
self.valid_ngram_draft[i, : self.valid_ngram_num_drafts[i]].tolist()
)
else:
draft_token_ids.append([])
return draft_token_ids
def propose(
self,
sampled_token_ids: list[list[int]],
req_ids: list[str],
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
spec_decode_unsupported_reqs: set,
) -> list[list[int]]:
# find which requests need ngram proposals
valid_ngram_requests = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
continue
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = req_ids[i]
if req_id in spec_decode_unsupported_reqs:
continue
num_tokens = num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
continue
valid_ngram_requests.append(i)
draft_token_ids = self.batch_propose(
len(sampled_token_ids),
valid_ngram_requests,
num_tokens_no_spec,
token_ids_cpu,
)
return draft_token_ids
def load_model(self, *args, **kwargs):
# No model to load.
pass
@njit(parallel=True)
def batch_propose_numba(
valid_ngram_requests: list,
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
min_n: int,
max_n: int,
max_model_len: int,
k: int,
valid_ngram_draft: np.ndarray,
valid_ngram_num_drafts: np.ndarray,
):
for i in prange(len(valid_ngram_requests)):
idx = valid_ngram_requests[i]
num_tokens = num_tokens_no_spec[idx]
context_token_ids = token_ids_cpu[idx, :num_tokens]
drafter_output = _find_longest_matched_ngram_and_propose_tokens(
origin_tokens=context_token_ids,
min_ngram=min_n,
max_ngram=max_n,
max_model_len=max_model_len,
k=k,
)
valid_ngram_num_drafts[idx] = drafter_output.shape[0]
if len(drafter_output):
valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output
@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(
origin_tokens: np.ndarray,
min_ngram: int,
max_ngram: int,
max_model_len: int,
k: int,
) -> np.ndarray:
"""
Find the longest n-gram which matches the suffix of the given tokens
whose length is within [min_ngram, max_ngram] (inclusive).
If found, we will extract k right after the matched ngram.
"""
# Do not generate draft tokens is context is shorter than minimum n-gram
total_token = origin_tokens.shape[0]
if total_token < min_ngram:
return np.empty((0,), dtype=origin_tokens.dtype)
# Do not generate draft tokens beyond the max model length.
k = min(k, max_model_len - total_token)
if k <= 0:
return np.empty((0,), dtype=origin_tokens.dtype)
# Flip tokens, and the goal become to find longest ngram
# on the rightmost position which matches the prefix with
# length [min_n, max_n] (inclusive).
tokens = origin_tokens[::-1]
# Longest prefix (not including itself) which is a suffix of
# the current position.
# lps[i] = max{v, where tokens[0:v] == tokens[i+1-v:i+1]}
#
# As ngram is capped by max_ngram to save memory, we only need to
# store lps for the first max_ngram prefix.
lps = np.zeros(max_ngram, dtype=np.int32)
longest_ngram = 0
position = 0
# lps[0] always equal to 0, we start with index 1
prev_lps = 0
i = 1
while i < total_token:
# tokens[:prev_lps] is the longest prefix as a suffix of tokens[:i]
if tokens[prev_lps] == tokens[i]:
# Token match: tokens[:prev_lps+1] is the longest prefix as
# a suffix of tokens[:i+1]
prev_lps += 1
# Check if we found a longer valid ngram.
#
# Update position when longest_ngram matched prev_lps,
# as we want to get the target n-gram of the earliest position
# in the original tokens (i.e.
# latest position in the reversed tokens)
if prev_lps >= longest_ngram:
longest_ngram = prev_lps
position = i
if i < max_ngram:
# Store LPS for the first max_ngram prefix
lps[i] = prev_lps
if prev_lps == max_ngram:
# When prev_lps reached max_ngram, update prev_lps
# to lps[max_ngram-1] to avoid matching ngram
# longer than max_ngram
prev_lps = lps[max_ngram - 1]
i += 1
elif prev_lps != 0:
# Token mismatch: try the second-longest prefix
# among all suffix of tokens[:i],
# which is the longest prefix of tokens[:prev_lps]
prev_lps = lps[prev_lps - 1]
else:
# Token mismatch, and no more prefix (except empty string)
# as a suffix of tokens[:i]
i += 1
if longest_ngram < min_ngram:
# No valid ngram is found
return np.empty((0,), dtype=origin_tokens.dtype)
# Flip the position back, so in origin_tokens,
# origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
# is the matched ngram, so we should start drafting tokens from
# total_token-1-position+longest_ngram
start_position = total_token - 1 - position + longest_ngram
k = min(k, total_token - start_position)
return origin_tokens[start_position : start_position + k]

View File

@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch
class SuffixDecodingProposer:
"""
Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
This class imports and uses the official implementation from Arctic Inference
(https://github.com/snowflakedb/ArcticInference).
"""
def __init__(self, vllm_config: VllmConfig):
config = vllm_config.speculative_config
self.num_speculative_tokens = config.num_speculative_tokens
self.max_tree_depth = config.suffix_decoding_max_tree_depth
self.max_spec_factor = config.suffix_decoding_max_spec_factor
self.min_token_prob = config.suffix_decoding_min_token_prob
self.max_model_len = vllm_config.model_config.max_model_len
# Lazy import to avoid error when Suffix Decoding is not used.
from arctic_inference.suffix_decoding import SuffixDecodingCache
# Initialize and empty cache. This object will take care of caching request
# outputs, evicting old requests, and manages the per-prompt suffix trees.
self.suffix_cache = SuffixDecodingCache(
max_tree_depth=config.suffix_decoding_max_tree_depth,
max_cached_requests=config.suffix_decoding_max_cached_requests,
)
def propose(
self,
input_batch: InputBatch,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
"""
Propose speculative tokens for each request in the input batch. Suffix Decoding
will speculate a dynamic number of tokens for each request every decoding step,
so each entry in the returned list may have different lengths.
"""
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
if not sampled_ids:
# Skip speculative decoding for partial prefills.
draft_token_ids.append([])
continue
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = input_batch.req_ids[i]
if req_id in input_batch.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue
num_tokens = input_batch.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue
index = input_batch.req_id_to_index[req_id]
if req_id not in self.suffix_cache.active_requests:
if req_id in self.suffix_cache.cached_requests:
# Reset the suffix cache for this request.
self.suffix_cache.evict_cached_response(req_id)
num_prompt_tokens = input_batch.num_prompt_tokens[index]
prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
# Start a new request, this will build the suffix tree for that prompt.
self.suffix_cache.start_request(req_id, prompt_token_ids)
# Append the newly sampled ids to the suffix cache for this request.
self.suffix_cache.add_active_response(req_id, sampled_ids)
# Suffix decoding only uses the most recent tokens up to max_tree_depth, so
# we extract the pattern from the end of the input.
start = max(0, num_tokens - self.max_tree_depth)
pattern = input_batch.token_ids_cpu[i, start:num_tokens]
draft = self.suffix_cache.speculate(
req_id,
pattern,
max_spec_tokens=min(
self.num_speculative_tokens, self.max_model_len - num_tokens - 1
),
max_spec_factor=self.max_spec_factor,
min_token_prob=self.min_token_prob,
)
draft_token_ids.append(draft.token_ids)
# Stop requests that were not seen in the input batch.
for req_id in (
self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
):
self.suffix_cache.stop_request(req_id)
return draft_token_ids
def load_model(self, *args, **kwargs):
# No model to load.
pass

View File

@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
_SAMPLING_EPS = 1e-5
def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
"""True if request is incompatible with speculative decoding"""
return (
sampling_params.frequency_penalty != 0.0
or sampling_params.presence_penalty != 0.0
or sampling_params.repetition_penalty != 1.0
or sampling_params.min_p > _SAMPLING_EPS
or sampling_params.logprobs is not None
)
@triton.jit
def eagle_prepare_inputs_padded_kernel(
cu_num_draft_tokens_ptr, # [num_reqs]
valid_sampled_tokens_count_ptr, # [num_reqs]
query_start_loc_gpu_ptr, # [num_reqs + 1]
token_indices_to_sample_ptr, # [num_reqs] (output)
num_reqs, # tl.int32
):
"""
Fused kernel for Eagle prepare_input_padded. This kernel computes the
token index to sample for each request, taking into account the number
of draft tokens and the number of valid sampled tokens (which is one more than
the number of accepted tokens).
"""
req_idx = tl.program_id(axis=0)
if req_idx >= num_reqs:
return
# Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive
# cumulative sum (first entry is the first value, not zero).
cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = 0
if req_idx == 0:
num_draft_tokens = cu_draft_curr
else:
cu_draft_prev = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
num_draft_tokens = cu_draft_curr - cu_draft_prev
valid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx)
num_rejected_tokens = num_draft_tokens + 1 - valid_count
num_rejected_tokens = tl.where(num_draft_tokens > 0, num_rejected_tokens, 0)
# query_start_loc[req_idx + 1] is the start position of the next request,
# which is one past the last token of this request.
q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + req_idx + 1) - 1
index_to_sample = q_last_tok_idx - num_rejected_tokens
tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
@triton.jit
def eagle_prepare_next_token_padded_kernel(
sampled_token_ids_ptr, # [num_reqs, num_sampled_tokens_per_req]
discard_request_mask_ptr, # [num_reqs]
backup_next_token_ids_ptr, # [num_reqs]
next_token_ids_ptr, # [num_reqs] (output)
valid_sampled_tokens_count_ptr, # [num_reqs] (output)
vocab_size, # tl.int32
num_sampled_tokens_per_req, # tl.int32 (num_spec_tokens + 1)
num_reqs, # tl.int32
stride_sampled_token_ids, # tl.int32 (stride for dim 0)
BLOCK_SIZE_TOKENS: tl.constexpr, # Power-of-2 >= num_sampled_tokens_per_req
):
"""
Fused kernel for Eagle prepare_next_token_ids_padded. This kernel computes the
number of valid (1 + accepted) tokens for each request, and the corresponding
"next" token id to sample from during speculative decoding. This is the
"last accepted token" from the sampled tokens, or the backup token if no
tokens were accepted or if the request is marked as discarded.
"""
req_idx = tl.program_id(axis=0)
if req_idx >= num_reqs:
return
# Check if this request is discarded.
is_discarded = tl.load(discard_request_mask_ptr + req_idx)
if is_discarded:
backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
valid_count = tl.full((), 0, dtype=tl.uint32)
tl.store(next_token_ids_ptr + req_idx, backup_token)
tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
else:
# Count the number of valid tokens among the sampled tokens.
token_offs = tl.arange(0, BLOCK_SIZE_TOKENS)
token_mask = token_offs < num_sampled_tokens_per_req
row_ptr = sampled_token_ids_ptr + req_idx * stride_sampled_token_ids
token_ids = tl.load(row_ptr + token_offs, mask=token_mask, other=-1)
# Rejected tokens are -1, valid tokens are in [0, vocab_size)
is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask
valid_count = tl.sum(is_valid_mask)
if valid_count > 0:
# Guaranteed to be well-defined since
# valid_count > 0 implies is_valid_mask is not empty
last_valid_index = tl.max(tl.where(is_valid_mask, token_offs, -1))
# Select the token at that index, using a sum trick since
# we don't want to load again to access token_ids[last_valid_index].
last_valid_token = tl.sum(
tl.where(token_offs == last_valid_index, token_ids, 0)
)
tl.store(next_token_ids_ptr + req_idx, last_valid_token)
else:
# No valid tokens found, use backup token
backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
tl.store(next_token_ids_ptr + req_idx, backup_token)
tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)