This commit is contained in:
root
2026-03-05 18:06:10 +08:00
commit 809cecae09
2569 changed files with 478204 additions and 0 deletions

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

1229
v1/spec_decode/eagle.py Normal file

File diff suppressed because it is too large Load Diff

73
v1/spec_decode/medusa.py Normal file
View File

@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.v1.sample.metadata import SamplingMetadata
# Initialize logger
logger = init_logger(__name__)
class MedusaProposer:
    """Proposer that generates draft tokens with Medusa heads.

    The Medusa draft model maps target-model hidden states to one set of
    logits per speculative position; the greedy (argmax) token of each
    head becomes the proposed token for that position.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # Keep the config pieces needed for sizing dummy inputs and for
        # loading the draft model later.
        self.vllm_config = vllm_config
        self.device = device
        scheduler_config = vllm_config.scheduler_config
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        draft_model_config = vllm_config.speculative_config.draft_model_config
        self.hidden_size = draft_model_config.get_hidden_size()
        self.dtype = vllm_config.model_config.dtype

    def propose(
        self,
        target_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        """Return one draft-token list per request in the batch."""
        # Run the Medusa heads on the target model's hidden states.
        head_outputs = self.model(target_hidden_states)
        logits_per_head = self.model.compute_logits(head_outputs)
        # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
        # synchronization.
        per_head_tokens = [
            head_logits.argmax(dim=-1).tolist() for head_logits in logits_per_head
        ]
        # Transpose [head][request] -> [request][head].
        return [list(request_tokens) for request_tokens in zip(*per_head_tokens)]

    def load_model(self, target_model: nn.Module) -> None:
        """Load the Medusa head as the draft model."""
        from vllm.compilation.backends import set_model_tag

        # Tag the head so its compiled artifacts stay separate from the
        # target model's.
        with set_model_tag("medusa_head"):
            self.model = get_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.speculative_config.draft_model_config,
            )
        assert not (
            is_mixture_of_experts(self.model)
            and self.vllm_config.parallel_config.enable_eplb
        ), "EPLB for Medusa is not supported"

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        """Warmup forward pass with an all-zero hidden-state buffer.

        NOTE(review): the buffer is always sized to max_num_tokens while
        num_tokens only feeds the forward context — confirm intentional.
        """
        hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device,
        )
        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
            self.model(hidden_states)

View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import numpy as np
import torch
@dataclass
class SpecDecodeMetadata:
    """Token-index bookkeeping for one speculative-decoding step."""

    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [batch_size]
    cu_num_sampled_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        # Longest draft in the batch; used to size per-step buffers.
        self.max_spec_len = max(self.num_draft_tokens)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build placeholder metadata (all index tensors zero) for warmup."""
        batch_size = len(draft_token_ids)
        draft_lens = [len(ids) for ids in draft_token_ids]
        # Each request also samples one bonus token after its draft.
        sampled_lens = [draft_len + 1 for draft_len in draft_lens]

        flat_ids = [tok for ids in draft_token_ids for tok in ids]
        num_tokens = len(flat_ids)
        draft_ids_tensor = torch.tensor(flat_ids, dtype=torch.int32, device=device)

        cu_draft = torch.from_numpy(np.cumsum(draft_lens, dtype=np.int32)).to(device)
        cu_sampled = torch.from_numpy(np.cumsum(sampled_lens, dtype=np.int32)).to(
            device
        )

        def _zeros(n: int) -> torch.Tensor:
            # Dummy index tensors are all-zero; real ones are filled later.
            return torch.zeros(n, dtype=torch.int32, device=device)

        return cls(
            draft_token_ids=draft_ids_tensor,
            num_draft_tokens=draft_lens,
            cu_num_draft_tokens=cu_draft,
            cu_num_sampled_tokens=cu_sampled,
            target_logits_indices=_zeros(num_tokens),
            bonus_logits_indices=_zeros(batch_size),
            logits_indices=_zeros(num_tokens + batch_size),
        )

224
v1/spec_decode/metrics.py Normal file
View File

@@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from dataclasses import dataclass, field
import numpy as np
import prometheus_client
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class SpecDecodingStats:
    """Per-step iteration decoding stats from scheduler.

    Each scheduler step, statistics on spec decoding performance are
    aggregated across requests by the scheduler and returned to the
    frontend in EngineCoreOutputs->SchedulerStats.
    """

    num_spec_tokens: int
    num_drafts: int = 0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0
    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)

    @classmethod
    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
        """Create zeroed stats with one per-position slot per spec token."""
        per_pos = [0 for _ in range(num_spec_tokens)]
        return cls(
            num_spec_tokens=num_spec_tokens,
            num_accepted_tokens_per_pos=per_pos,
        )

    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
        """Record one draft: how many tokens it proposed and how many stuck."""
        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        assert num_accepted_tokens <= self.num_spec_tokens
        # Accepted tokens are always a prefix of the draft, so bump every
        # position up to the acceptance point.
        for pos in range(num_accepted_tokens):
            self.num_accepted_tokens_per_pos[pos] += 1
class SpecDecodingLogging:
    """Aggregate and log spec decoding metrics.

    LoggingStatLogger aggregates per-iteration metrics over a set
    time interval using observe() and then logs them using log()
    before resetting to zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # One entry per observed scheduler step since the last log().
        self.num_drafts: list[int] = []
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []
        self.accepted_tokens_per_pos_lists: list[list[int]] = []
        self.last_log_time = time.monotonic()

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        """Accumulate one step's stats for the next log() call."""
        stats = spec_decoding_stats
        self.num_drafts.append(stats.num_drafts)
        self.num_draft_tokens.append(stats.num_draft_tokens)
        self.num_accepted_tokens.append(stats.num_accepted_tokens)
        self.accepted_tokens_per_pos_lists.append(stats.num_accepted_tokens_per_pos)

    def log(self, log_fn=logger.info):
        """Emit aggregated metrics via log_fn and reset the accumulators."""
        if not self.num_drafts:
            # Nothing observed since the last log; skip.
            return
        num_drafts = np.sum(self.num_drafts)
        num_draft_tokens = np.sum(self.num_draft_tokens)
        num_accepted_tokens = np.sum(self.num_accepted_tokens)

        elapsed_time = time.monotonic() - self.last_log_time
        if elapsed_time > 0:
            draft_throughput = num_draft_tokens / elapsed_time
            accepted_throughput = num_accepted_tokens / elapsed_time
        else:
            draft_throughput = 0
            accepted_throughput = 0

        if num_draft_tokens > 0:
            draft_acceptance_rate = num_accepted_tokens / num_draft_tokens * 100
        else:
            draft_acceptance_rate = float("nan")

        # Conventionally, mean acceptance length includes the bonus token
        mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)

        # Per-position acceptance rate across all observed steps.
        pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
        acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
        rates_str = ", ".join(f"{p:.3f}" for p in acceptance_rates)

        log_fn(
            "SpecDecoding metrics: "
            "Mean acceptance length: %.2f, "
            "Accepted throughput: %.2f tokens/s, "
            "Drafted throughput: %.2f tokens/s, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens, "
            "Per-position acceptance rate: %s, "
            "Avg Draft acceptance rate: %.1f%%",
            mean_acceptance_length,
            accepted_throughput,
            draft_throughput,
            num_accepted_tokens,
            num_draft_tokens,
            rates_str,
            draft_acceptance_rate,
        )
        self.reset()
class SpecDecodingProm:
    """Record spec decoding metrics in Prometheus.

    The acceptance rate can be calculated using a PromQL query:

      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
      rate(vllm:spec_decode_num_draft_tokens_total[$interval])

    The mean acceptance length (conventionally including bonus tokens)
    can be calculated using:

      1 + (
        rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
        rate(vllm:spec_decode_num_drafts[$interval]))

    A per-position acceptance rate vector can be computed using

      vllm:spec_decode_num_accepted_tokens_per_pos[$interval] /
      vllm:spec_decode_num_drafts[$interval]
    """

    # Class attribute so tests can substitute a fake counter implementation.
    _counter_cls = prometheus_client.Counter

    def __init__(
        self,
        speculative_config: SpeculativeConfig | None,
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[str]],
    ):
        self.spec_decoding_enabled = speculative_config is not None
        if not self.spec_decoding_enabled:
            # No spec decoding: register no metrics; observe() becomes a no-op.
            return

        def _per_engine_counter(name: str, documentation: str):
            # Register one counter and bind it to each engine's label values.
            counter = self._counter_cls(
                name=name, documentation=documentation, labelnames=labelnames
            )
            return make_per_engine(counter, per_engine_labelvalues)

        self.counter_spec_decode_num_drafts = _per_engine_counter(
            "vllm:spec_decode_num_drafts", "Number of spec decoding drafts."
        )
        self.counter_spec_decode_num_draft_tokens = _per_engine_counter(
            "vllm:spec_decode_num_draft_tokens", "Number of draft tokens."
        )
        self.counter_spec_decode_num_accepted_tokens = _per_engine_counter(
            "vllm:spec_decode_num_accepted_tokens", "Number of accepted tokens."
        )

        assert speculative_config is not None
        num_spec_tokens = (
            speculative_config.num_speculative_tokens
            if self.spec_decoding_enabled
            else 0
        )
        # Per-position counters carry an extra "position" label.
        pos_labelnames = labelnames + ["position"]
        base_counter = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens_per_pos",
            documentation="Accepted tokens per draft position.",
            labelnames=pos_labelnames,
        )
        # One bound child counter per draft position, per engine index.
        self.counter_spec_decode_num_accepted_tokens_per_pos: dict[
            int, list[prometheus_client.Counter]
        ] = {
            idx: [base_counter.labels(*lv, str(pos)) for pos in range(num_spec_tokens)]
            for idx, lv in per_engine_labelvalues.items()
        }

    def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0):
        """Add one step's stats to the counters for the given engine."""
        if not self.spec_decoding_enabled:
            return
        stats = spec_decoding_stats
        self.counter_spec_decode_num_drafts[engine_idx].inc(stats.num_drafts)
        self.counter_spec_decode_num_draft_tokens[engine_idx].inc(
            stats.num_draft_tokens
        )
        self.counter_spec_decode_num_accepted_tokens[engine_idx].inc(
            stats.num_accepted_tokens
        )
        per_pos_counters = self.counter_spec_decode_num_accepted_tokens_per_pos[
            engine_idx
        ]
        for pos, counter in enumerate(per_pos_counters):
            counter.inc(stats.num_accepted_tokens_per_pos[pos])
def make_per_engine(
    counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[str]]
):
    """Bind *counter* to each engine's label values, keyed by engine index."""
    bound: dict[int, prometheus_client.Counter] = {}
    for engine_idx, labelvalues in per_engine_labelvalues.items():
        bound[engine_idx] = counter.labels(*labelvalues)
    return bound

View File

@@ -0,0 +1,291 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import numpy as np
from numba import get_num_threads, jit, njit, prange, set_num_threads
from vllm.config import VllmConfig
class NgramProposer:
    """Proposes draft tokens by prompt-lookup (n-gram) matching.

    For each request, the longest suffix n-gram (length within
    [min_n, max_n]) that re-occurs earlier in the context is located, and
    up to ``k`` tokens following the earliest such occurrence are
    proposed as the draft.
    """

    def __init__(self, vllm_config: VllmConfig):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        # Minimum length of the n-gram to match.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        # Maximum length of the n-gram to match.
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        # Number of tokens follow the match. If there are less than k
        # tokens follow the match, we will return the maximum amount of
        # tokens until the end.
        self.k = vllm_config.speculative_config.num_speculative_tokens
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len

        # Pre-allocate buffers for numba batch propose, indexed by request
        # position in the batch.
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32)
        self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)

        # Threshold of total number of tokens in the batch to enable
        # multi-threading in numba batch propose.
        self.num_tokens_threshold = 8192

        tp_size = vllm_config.parallel_config.tensor_parallel_size
        cpu_count = os.cpu_count()
        # Max number of threads for numba parallel processing.
        if cpu_count:
            # Divide by 2 to use physical cores
            # and not logical cores (hyper-threading).
            # Cap the number of threads to 8 to avoid using too many threads
            # since other components like frontend (incl tokenization)
            # and Structured Outputs also use multiple threads.
            # TODO(ekagra-ranjan): bump up the cap from 1 to 8
            # when TP parallelization for ngram is implemented.
            self.num_numba_thread_available = min(1, (cpu_count // 2))
            # Divide by tp_size to ensure each tensor parallel rank
            # has some threads since all ranks will run this.
            self.num_numba_thread_available //= tp_size
        else:
            self.num_numba_thread_available = 1

        # Trigger Numba JIT compilation for N-gram proposer.
        # This usually takes less than 1 second.
        # NOTE(review): all sampled arrays here are empty, so every request
        # is skipped and the numba kernel is never invoked by this call —
        # confirm the JIT warmup actually happens.
        self.propose(
            [np.array([])] * 1024,
            [""] * 1024,
            np.zeros(1024, dtype=np.int32),
            np.zeros((1024, self.max_model_len), dtype=np.int32),
            set(),
        )

    def batch_propose(
        self,
        num_requests: int,
        valid_ngram_requests: list,
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
    ) -> list[list[int]]:
        """Batch version of ngram proposer using numba for acceleration.

        Args:
            num_requests:
                Total number of requests in the batch.
            valid_ngram_requests:
                Indices of requests that need ngram proposals.
            num_tokens_no_spec:
                Numpy array of shape (batch_size,) representing the number
                of tokens without speculative tokens for each request.
            token_ids_cpu:
                Numpy array of shape (batch_size, max_model_len)
                representing the token IDs for each request.

        Returns:
            list[list[int]]:
                A list where each element is a list of proposed
                token IDs for the corresponding request.
        """
        draft_token_ids: list[list[int]] = []

        # Only run batch propose if there are requests needing ngram proposals.
        # avoid calling numba function with empty list which causes error
        # ValueError: cannot compute fingerprint of empty list
        if num_ngram_requests := len(valid_ngram_requests):
            original_num_numba_threads = get_num_threads()
            # Ensure we use at least one thread.
            # If total tokens is small, using multiple threads
            # may slow down due to overhead.
            total_tokens = np.sum(num_tokens_no_spec)
            if total_tokens >= self.num_tokens_threshold:
                final_num_threads = max(
                    1, min(self.num_numba_thread_available, num_ngram_requests)
                )
                set_num_threads(final_num_threads)
            else:
                set_num_threads(1)

            batch_propose_numba(
                valid_ngram_requests,
                num_tokens_no_spec,
                token_ids_cpu,
                self.min_n,
                self.max_n,
                self.max_model_len,
                self.k,
                self.valid_ngram_draft,
                self.valid_ngram_num_drafts,
            )

            # Restore original number of threads.
            set_num_threads(original_num_numba_threads)

        # Build a set once for O(1) membership tests; the previous
        # `i in list` scan made this loop O(n^2) in the batch size.
        valid_request_set = set(valid_ngram_requests)
        for i in range(num_requests):
            if i in valid_request_set and self.valid_ngram_num_drafts[i] > 0:
                draft_token_ids.append(
                    self.valid_ngram_draft[i, : self.valid_ngram_num_drafts[i]].tolist()
                )
            else:
                draft_token_ids.append([])
        return draft_token_ids

    def propose(
        self,
        sampled_token_ids: list[np.ndarray],
        req_ids: list[str],
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
        spec_decode_unsupported_reqs: set,
    ) -> list[list[int]]:
        """Filter requests eligible for ngram proposals and run the batch."""
        valid_ngram_requests = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = sampled_ids.shape[0]
            if not num_sampled_ids:
                # Skip speculative decoding.
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = req_ids[i]
            if req_id in spec_decode_unsupported_reqs:
                continue

            num_tokens = num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                continue

            valid_ngram_requests.append(i)

        return self.batch_propose(
            len(sampled_token_ids),
            valid_ngram_requests,
            num_tokens_no_spec,
            token_ids_cpu,
        )

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass
@njit(parallel=True)
def batch_propose_numba(
    valid_ngram_requests: list,
    num_tokens_no_spec: np.ndarray,
    token_ids_cpu: np.ndarray,
    min_n: int,
    max_n: int,
    max_model_len: int,
    k: int,
    valid_ngram_draft: np.ndarray,
    valid_ngram_num_drafts: np.ndarray,
):
    """Fill the per-request draft buffers for every valid request.

    Outputs are written into ``valid_ngram_draft`` / ``valid_ngram_num_drafts``
    at the *request* index, because the caller reads these buffers by
    request index.
    """
    for i in prange(len(valid_ngram_requests)):
        idx = valid_ngram_requests[i]
        num_tokens = num_tokens_no_spec[idx]
        context_token_ids = token_ids_cpu[idx, :num_tokens]
        drafter_output = _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=context_token_ids,
            min_ngram=min_n,
            max_ngram=max_n,
            max_model_len=max_model_len,
            k=k,
        )
        # BUG FIX: results were previously stored at the compacted loop
        # index ``i``; when valid_ngram_requests is non-contiguous the
        # caller (which indexes by request id) would read empty or stale
        # slots. Store at the request index ``idx`` instead.
        valid_ngram_num_drafts[idx] = drafter_output.shape[0]
        if len(drafter_output):
            valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output
@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(
    origin_tokens: np.ndarray,
    min_ngram: int,
    max_ngram: int,
    max_model_len: int,
    k: int,
) -> np.ndarray:
    """
    Find the longest n-gram which matches the suffix of the given tokens
    whose length is within [min_ngram, max_ngram] (inclusive).

    If found, we will extract k right after the matched ngram.

    Implementation: a KMP failure-function (LPS) pass over the reversed
    token sequence, so the search is O(len(origin_tokens)) time with
    O(max_ngram) extra space. Returns an empty array when no match is
    found or no tokens may be drafted.
    """
    # Do not generate draft tokens if context is shorter than minimum n-gram
    total_token = origin_tokens.shape[0]
    if total_token < min_ngram:
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Do not generate draft tokens beyond the max model length.
    k = min(k, max_model_len - total_token)
    if k <= 0:
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Flip tokens, and the goal become to find longest ngram
    # on the rightmost position which matches the prefix with
    # length [min_n, max_n] (inclusive).
    tokens = origin_tokens[::-1]

    # Longest prefix (not including itself) which is a suffix of
    # the current position.
    # lps[i] = max{v, where tokens[0:v] == tokens[i+1-v:i+1]}
    #
    # As ngram is capped by max_ngram to save memory, we only need to
    # store lps for the first max_ngram prefix.
    lps = np.zeros(max_ngram, dtype=np.int32)
    # Best (longest, latest-in-reversed-order) match found so far.
    longest_ngram = 0
    position = 0

    # lps[0] always equal to 0, we start with index 1
    prev_lps = 0
    i = 1
    while i < total_token:
        # tokens[:prev_lps] is the longest prefix as a suffix of tokens[:i]
        if tokens[prev_lps] == tokens[i]:
            # Token match: tokens[:prev_lps+1] is the longest prefix as
            # a suffix of tokens[:i+1]
            prev_lps += 1
            # Check if we found a longer valid ngram.
            #
            # Update position when longest_ngram matched prev_lps,
            # as we want to get the target n-gram of the earliest position
            # in the original tokens (i.e.
            # latest position in the reversed tokens)
            if prev_lps >= longest_ngram:
                longest_ngram = prev_lps
                position = i
            if i < max_ngram:
                # Store LPS for the first max_ngram prefix
                lps[i] = prev_lps
            if prev_lps == max_ngram:
                # When prev_lps reached max_ngram, update prev_lps
                # to lps[max_ngram-1] to avoid matching ngram
                # longer than max_ngram
                prev_lps = lps[max_ngram - 1]
            i += 1
        elif prev_lps != 0:
            # Token mismatch: try the second longest prefix
            # among all suffix of tokens[:i],
            # which is the longest prefix of tokens[:prev_lps]
            prev_lps = lps[prev_lps - 1]
        else:
            # Token mismatch, and no more prefix (except empty string)
            # as a suffix of tokens[:i]
            i += 1

    if longest_ngram < min_ngram:
        # No valid ngram is found
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Flip the position back, so in origin_tokens,
    # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
    # is the matched ngram, so we should start drafting tokens from
    # total_token-1-position+longest_ngram
    start_position = total_token - 1 - position + longest_ngram
    # Clamp again so the draft never runs past the end of the context.
    k = min(k, total_token - start_position)
    return origin_tokens[start_position : start_position + k]

View File

@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch
class SuffixDecodingProposer:
    """
    Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
    This class imports and uses the official implementation from Arctic Inference
    (https://github.com/snowflakedb/ArcticInference).
    """

    def __init__(self, vllm_config: VllmConfig):
        config = vllm_config.speculative_config
        # Upper bound on draft length per step (further capped per request).
        self.num_speculative_tokens = config.num_speculative_tokens
        # Suffix-tree depth; also bounds the lookup-pattern length below.
        self.max_tree_depth = config.suffix_decoding_max_tree_depth
        self.max_spec_factor = config.suffix_decoding_max_spec_factor
        self.min_token_prob = config.suffix_decoding_min_token_prob
        self.max_model_len = vllm_config.model_config.max_model_len

        # Lazy import to avoid error when Suffix Decoding is not used.
        from arctic_inference.suffix_decoding import SuffixDecodingCache

        # Initialize an empty cache. This object will take care of caching request
        # outputs, evicting old requests, and manages the per-prompt suffix trees.
        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=config.suffix_decoding_max_tree_depth,
            max_cached_requests=config.suffix_decoding_max_cached_requests,
        )

    def propose(
        self,
        input_batch: InputBatch,
        sampled_token_ids: list[np.ndarray],
    ) -> list[list[int]]:
        """
        Propose speculative tokens for each request in the input batch. Suffix Decoding
        will speculate a dynamic number of tokens for each request every decoding step,
        so each entry in the returned list may have different lengths.
        """
        # Entries are token-id lists (empty for skipped requests), matching
        # the declared return type.
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            if sampled_ids.shape[0] == 0:
                # Skip speculative decoding for partial prefills.
                draft_token_ids.append([])
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = input_batch.req_ids[i]
            if req_id in input_batch.spec_decode_unsupported_reqs:
                draft_token_ids.append([])
                continue

            num_tokens = input_batch.num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                draft_token_ids.append([])
                continue

            # NOTE(review): `index` should equal `i` since req_id comes from
            # req_ids[i]; the code mixes both below — confirm they agree.
            index = input_batch.req_id_to_index[req_id]
            if req_id not in self.suffix_cache.active_requests:
                if req_id in self.suffix_cache.cached_requests:
                    # Reset the suffix cache for this request.
                    self.suffix_cache.evict_cached_response(req_id)
                num_prompt_tokens = input_batch.num_prompt_tokens[index]
                prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
                # Start a new request, this will build the suffix tree for that prompt.
                self.suffix_cache.start_request(req_id, prompt_token_ids)

            # Append the newly sampled ids to the suffix cache for this request.
            self.suffix_cache.add_active_response(req_id, sampled_ids.tolist())

            # Suffix decoding only uses the most recent tokens up to max_tree_depth, so
            # we extract the pattern from the end of the input.
            start = max(0, num_tokens - self.max_tree_depth)
            pattern = input_batch.token_ids_cpu[i, start:num_tokens]
            draft = self.suffix_cache.speculate(
                req_id,
                pattern,
                # Never draft past the max model length (reserve one slot
                # for the bonus token).
                max_spec_tokens=min(
                    self.num_speculative_tokens, self.max_model_len - num_tokens - 1
                ),
                max_spec_factor=self.max_spec_factor,
                min_token_prob=self.min_token_prob,
            )
            draft_token_ids.append(draft.token_ids)

        # Stop requests that were not seen in the input batch.
        for req_id in (
            self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
        ):
            self.suffix_cache.stop_request(req_id)

        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass

16
v1/spec_decode/utils.py Normal file
View File

@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.sampling_params import SamplingParams
# Threshold below which min_p is treated as effectively unset.
_SAMPLING_EPS = 1e-5


def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
    """True if request is incompatible with speculative decoding"""
    # Any nontrivial penalty, a min_p filter, or a logprobs request makes
    # the request ineligible; check each in turn (guard clauses preserve
    # the original short-circuit order).
    if sampling_params.frequency_penalty != 0.0:
        return True
    if sampling_params.presence_penalty != 0.0:
        return True
    if sampling_params.repetition_penalty != 1.0:
        return True
    if sampling_params.min_p > _SAMPLING_EPS:
        return True
    return sampling_params.logprobs is not None