This commit is contained in:
root
2026-04-09 11:23:47 +08:00
parent 8082d5f4b2
commit 72387e4fa8
1885 changed files with 611521 additions and 1 deletions

View File

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from typing_extensions import override
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model
logger = init_logger(__name__)
class DraftModelProposer(SpecDecodeBaseProposer):
    """Speculative-decoding proposer backed by a standalone draft model.

    The draft model consumes token ids only, so the target model's hidden
    states are not forwarded to it (``pass_hidden_states_to_model=False``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=False,
            runner=runner,
        )
        # Reject unsupported configurations as early as possible.
        self._raise_if_vocab_size_mismatch()
        self._raise_if_draft_tp_mismatch()

    def _raise_if_vocab_size_mismatch(self):
        # Draft and target vocabularies must agree token-for-token.
        self.speculative_config.verify_equal_vocab_size_if_draft_model()

    def _raise_if_draft_tp_mismatch(self):
        # Note(Tomas Ruiz) If we run the target model with TP > 1 and
        # the draft model with TP = 1, then the different TP ranks collide.
        # Specifically when all ranks compile the draft model on rank 0
        # (because TP=1), then the torch compile cache is overwritten and corrupted.
        # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
        # To prevent this error, we assert that both TP sizes must be the same.
        cfg = self.speculative_config
        tgt_tp = cfg.target_parallel_config.tensor_parallel_size
        draft_tp = cfg.draft_parallel_config.tensor_parallel_size
        if draft_tp == tgt_tp:
            return
        raise ValueError(
            f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
            f"must be the same. Got {draft_tp} and {tgt_tp}. "
            "Please pass 'draft_tensor_parallel_size' in the speculative_config."
        )

    @override
    def _get_model(self) -> nn.Module:
        # The draft model may be quantized differently and use a different
        # parallel layout than the target, so it is loaded under a
        # draft-specific vllm config.
        from vllm.compilation.backends import set_model_tag

        draft_vllm_config = create_vllm_config_for_draft_model(self.vllm_config)
        with set_model_tag("draft_model"):
            return get_model(
                vllm_config=draft_vllm_config,
                prefix="draft_model",
            )

    @override
    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        # Draft models keep their own embedding table.
        pass

    @override
    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        # Draft models keep their own lm_head.
        pass

1765
vllm/v1/spec_decode/eagle.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,78 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.v1.sample.metadata import SamplingMetadata
# Initialize logger
logger = init_logger(__name__)
class MedusaProposer:
    """
    Medusa proposer class for generating token sequences.

    Runs a set of Medusa heads over the target model's hidden states and
    greedily drafts one token per head.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache the config values needed at propose/dummy-run time.

        Args:
            vllm_config: Global vLLM config; must carry a speculative_config.
            device: Device on which dummy-run buffers are allocated.
        """
        # Save config parameters
        self.vllm_config = vllm_config
        assert vllm_config.speculative_config is not None, (
            "Speculative config must be set"
        )
        self.spec_config = vllm_config.speculative_config
        self.device = device
        # Upper bound on tokens per forward pass; sizes the dummy-run buffer.
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.hidden_size = self.spec_config.draft_model_config.get_hidden_size()
        self.dtype = vllm_config.model_config.dtype

    def propose(
        self,
        target_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,  # unused
    ) -> torch.Tensor:
        """Greedily draft one token per Medusa head.

        Args:
            target_hidden_states: Hidden states from the target model.
            sampling_metadata: Unused in this method; kept for interface
                parity with other proposers.
            slot_mappings: Unused.

        Returns:
            Tensor of shape [batch_size, num_heads] with drafted token ids.
        """
        # Generate blocks and compute logits (one logits tensor per head).
        blocks = self.model(target_hidden_states)
        logits = self.model.compute_logits(blocks)
        # Compute argmax for each Medusa head and stack into a single tensor
        # Shape: [batch_size, num_heads]
        draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
        return draft_tokens

    def load_model(self, target_model: nn.Module) -> None:
        """Load the Medusa heads under their own torch.compile model tag."""
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("medusa_head"):
            self.model = get_model(
                vllm_config=self.vllm_config,
                model_config=self.spec_config.draft_model_config,
            )
        assert not (
            is_mixture_of_experts(self.model)
            and self.vllm_config.parallel_config.enable_eplb
        ), "EPLB for Medusa is not supported"

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        """Warm-up forward pass over zero-filled hidden states.

        Bugfix: only the first `num_tokens` rows are fed to the model so the
        actual batch size agrees with the `num_tokens` declared in the
        forward context. Previously the full `max_num_tokens` buffer was
        passed regardless of `num_tokens`.
        """
        hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device,
        )
        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
            self.model(hidden_states[:num_tokens])

View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import numpy as np
import torch
@dataclass
class SpecDecodeMetadata:
    """Per-batch bookkeeping tensors for speculative decoding.

    All `cu_*` tensors are *inclusive* cumulative sums: the first entry is
    the first request's count, not zero.
    """

    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [batch_size]
    cu_num_sampled_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        # `default=0` keeps an empty batch from raising
        # "max() arg is an empty sequence".
        self.max_spec_len = max(self.num_draft_tokens, default=0)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build placeholder metadata from raw per-request draft ids.

        The three `*_indices` tensors are zero-filled dummies; only their
        shapes are meaningful.

        Args:
            draft_token_ids: Per-request lists of drafted token ids.
            device: Device on which all tensors are created.
        """
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        # Each request samples one bonus token on top of its drafts.
        num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]
        flattened_draft_token_ids = sum(draft_token_ids, [])
        num_tokens = len(flattened_draft_token_ids)
        draft_token_ids_tensor = torch.tensor(
            flattened_draft_token_ids, dtype=torch.int32, device=device
        )
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
            device
        )
        target_logits_indices = torch.zeros(
            num_tokens, dtype=torch.int32, device=device
        )
        bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device)
        logits_indices = torch.zeros(
            num_tokens + batch_size, dtype=torch.int32, device=device
        )
        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )

View File

@@ -0,0 +1,225 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from dataclasses import dataclass, field
import numpy as np
import prometheus_client
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class SpecDecodingStats:
    """Per-step iteration decoding stats from scheduler.

    Each scheduler step, statistics on spec decoding performance are
    aggregated across requests by the scheduler and returned to the
    frontend in EngineCoreOutputs->SchedulerStats.
    """

    num_spec_tokens: int
    num_drafts: int = 0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0
    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)

    @classmethod
    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
        # One acceptance bucket per speculative position, all zeroed.
        per_pos = [0 for _ in range(num_spec_tokens)]
        return cls(
            num_spec_tokens=num_spec_tokens,
            num_accepted_tokens_per_pos=per_pos,
        )

    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
        """Record one draft: how many tokens it had and how many stuck."""
        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        assert num_accepted_tokens <= self.num_spec_tokens
        # Acceptance is prefix-shaped: positions 0..num_accepted-1 each
        # gain one hit.
        self.num_accepted_tokens_per_pos[:num_accepted_tokens] = [
            count + 1
            for count in self.num_accepted_tokens_per_pos[:num_accepted_tokens]
        ]
class SpecDecodingLogging:
    """Aggregate and log spec decoding metrics.

    Per-iteration SpecDecodingStats are collected via observe(); log()
    emits a summary over the elapsed interval and clears the accumulators.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Per-iteration samples gathered since the last log().
        self.num_drafts: list[int] = []
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []
        self.accepted_tokens_per_pos_lists: list[list[int]] = []
        self.last_log_time = time.monotonic()

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        stats = spec_decoding_stats
        self.num_drafts.append(stats.num_drafts)
        self.num_draft_tokens.append(stats.num_draft_tokens)
        self.num_accepted_tokens.append(stats.num_accepted_tokens)
        self.accepted_tokens_per_pos_lists.append(stats.num_accepted_tokens_per_pos)

    def log(self, log_fn=logger.info):
        # Nothing observed this interval -> nothing to log.
        if not self.num_drafts:
            return
        total_drafts = np.sum(self.num_drafts)
        total_drafted = np.sum(self.num_draft_tokens)
        total_accepted = np.sum(self.num_accepted_tokens)

        elapsed = time.monotonic() - self.last_log_time
        drafted_per_s = total_drafted / elapsed if elapsed > 0 else 0
        accepted_per_s = total_accepted / elapsed if elapsed > 0 else 0

        if total_drafted > 0:
            acceptance_pct = total_accepted / total_drafted * 100
        else:
            acceptance_pct = float("nan")
        # Conventionally, mean acceptance length includes the bonus token
        mean_acceptance_length = 1 + (total_accepted / total_drafts)

        # Per-position acceptance: column-wise sum over all observed steps,
        # normalized by the total number of drafts.
        per_pos_rates = (
            np.sum(np.array(self.accepted_tokens_per_pos_lists), axis=0) / total_drafts
        )
        rates_str = ", ".join(f"{rate:.3f}" for rate in per_pos_rates)

        log_fn(
            "SpecDecoding metrics: "
            "Mean acceptance length: %.2f, "
            "Accepted throughput: %.2f tokens/s, "
            "Drafted throughput: %.2f tokens/s, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens, "
            "Per-position acceptance rate: %s, "
            "Avg Draft acceptance rate: %.1f%%",
            mean_acceptance_length,
            accepted_per_s,
            drafted_per_s,
            total_accepted,
            total_drafted,
            rates_str,
            acceptance_pct,
        )
        self.reset()
class SpecDecodingProm:
    """Record spec decoding metrics in Prometheus.

    The acceptance rate can be calculated using a PromQL query:

      rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
      rate(vllm:spec_decode_num_draft_tokens_total[$interval])

    The mean acceptance length (conventionally including bonus tokens)
    can be calculated using:

      1 + (
       rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
       rate(vllm:spec_decode_num_drafts[$interval]))

    A per-position acceptance rate vector can be computed using

      vllm:spec_decode_num_accepted_tokens_per_pos[$interval] /
      vllm:spec_decode_num_drafts[$interval]
    """

    # Counter class kept as a class attribute so tests can substitute it.
    _counter_cls = prometheus_client.Counter

    def __init__(
        self,
        speculative_config: SpeculativeConfig | None,
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        """Create the per-engine counters; no-op when spec decoding is off.

        Args:
            speculative_config: None when speculative decoding is disabled.
            labelnames: Label names shared by all counters.
            per_engine_labelvalues: Mapping of engine index -> label values.
        """
        self.spec_decoding_enabled = speculative_config is not None
        if not self.spec_decoding_enabled:
            return
        counter_drafts = self._counter_cls(
            name="vllm:spec_decode_num_drafts",
            documentation="Number of spec decoding drafts.",
            labelnames=labelnames,
        )
        self.counter_spec_decode_num_drafts = make_per_engine(
            counter_drafts, per_engine_labelvalues
        )
        counter_draft_tokens = self._counter_cls(
            name="vllm:spec_decode_num_draft_tokens",
            documentation="Number of draft tokens.",
            labelnames=labelnames,
        )
        self.counter_spec_decode_num_draft_tokens = make_per_engine(
            counter_draft_tokens, per_engine_labelvalues
        )
        counter_accepted_tokens = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens",
            documentation="Number of accepted tokens.",
            labelnames=labelnames,
        )
        self.counter_spec_decode_num_accepted_tokens = make_per_engine(
            counter_accepted_tokens, per_engine_labelvalues
        )
        # Narrow the Optional for type checkers; guaranteed by the early
        # return above. The former
        # `... if self.spec_decoding_enabled else 0` conditional here was
        # unreachable (disabled configs return early) and has been removed.
        assert speculative_config is not None
        num_spec_tokens = speculative_config.num_speculative_tokens
        pos_labelnames = labelnames + ["position"]
        base_counter = self._counter_cls(
            name="vllm:spec_decode_num_accepted_tokens_per_pos",
            documentation="Accepted tokens per draft position.",
            labelnames=pos_labelnames,
        )
        # One child counter per (engine, draft position) pair.
        self.counter_spec_decode_num_accepted_tokens_per_pos: dict[
            int, list[prometheus_client.Counter]
        ] = {
            idx: [base_counter.labels(*lv, str(pos)) for pos in range(num_spec_tokens)]
            for idx, lv in per_engine_labelvalues.items()
        }

    def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0):
        """Fold one step's stats into the given engine's counters."""
        if not self.spec_decoding_enabled:
            return
        self.counter_spec_decode_num_drafts[engine_idx].inc(
            spec_decoding_stats.num_drafts
        )
        self.counter_spec_decode_num_draft_tokens[engine_idx].inc(
            spec_decoding_stats.num_draft_tokens
        )
        self.counter_spec_decode_num_accepted_tokens[engine_idx].inc(
            spec_decoding_stats.num_accepted_tokens
        )
        for pos, counter in enumerate(
            self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]
        ):
            counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
def make_per_engine(
    counter: prometheus_client.Counter,
    per_engine_labelvalues: dict[int, list[object]],
):
    """Create a counter for each label value.

    Returns a dict mapping each engine index to the counter child bound to
    that engine's label values.
    """
    bound_counters = {}
    for engine_idx, labelvalues in per_engine_labelvalues.items():
        bound_counters[engine_idx] = counter.labels(*labelvalues)
    return bound_counters

View File

@@ -0,0 +1,285 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import numpy as np
import torch
from numba import get_num_threads, jit, njit, prange, set_num_threads
from vllm.config import VllmConfig
class NgramProposer:
    """Prompt-lookup (n-gram) draft proposer.

    For each request, finds the longest n-gram (length in [min_n, max_n])
    matching the suffix of the context and proposes up to k tokens that
    followed that earlier occurrence. Proposal runs on CPU via numba.
    """

    def __init__(self, vllm_config: VllmConfig):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None
        # Minimum length of the n-gram to match.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        # Maximum length of the n-gram to match.
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        # Number of tokens follow the match. If there are less than k
        # tokens follow the match, we will return the maximum amount of
        # tokens until the end.
        self.k = vllm_config.speculative_config.num_speculative_tokens
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len
        # Pre-allocate buffers for numba batch propose.
        # batch_propose_numba writes its per-request results here instead of
        # allocating inside the jitted function.
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32)
        self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)
        # Threshold of total number of tokens in the batch to enable
        # multi-threading in numba batch propose.
        self.num_tokens_threshold = 8192
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        cpu_count = os.cpu_count()
        # Max number of threads for numba parallel processing.
        if cpu_count:
            # Divide by 2 to use physical cores
            # and not logical cores (hyper-threading).
            # Cap the number of threads to 8 to avoid using too many threads
            # since other components like frontend (incl tokenization)
            # and Structured Outputs also use multiple threads.
            # TODO(ekagra-ranjan): bump up the cap from 1 to 8
            # when TP parallelization for ngram is implemented.
            self.num_numba_thread_available = min(1, (cpu_count // 2))
            # Divide by tp_size to ensure each tensor parallel rank
            # has some threads since all ranks will run this.
            self.num_numba_thread_available //= tp_size
        else:
            self.num_numba_thread_available = 1
        # Trigger Numba JIT compilation for N-gram proposer.
        # This usually takes less than 1 second.
        # 1024 empty requests exercise the jitted code paths once so the
        # first real batch does not pay the compile cost.
        self.propose(
            [[]] * 1024,
            np.zeros(1024, dtype=np.int32),
            np.zeros((1024, self.max_model_len), dtype=np.int32),
        )

    def batch_propose(
        self,
        num_requests: int,
        valid_ngram_requests: list,
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
    ) -> list[list[int]]:
        """Batch version of ngram proposer using numba for acceleration.

        Args:
            num_requests:
                Total number of requests in the batch; the returned list
                has one entry per request.
            valid_ngram_requests:
                Set of indices of requests that need ngram proposals.
            num_tokens_no_spec:
                Numpy array of shape (batch_size,) representing the number
                of tokens without speculative tokens for each request.
            token_ids_cpu:
                Numpy array of shape (batch_size, max_model_len)
                representing the token IDs for each request.
        Returns:
            list[list[int]]:
                A list where each element is a list of proposed
                token IDs for the corresponding request.
        """
        draft_token_ids: list[list[int]] = []
        # Only run batch propose if there are requests needing ngram proposals.
        # avoid calling numba function with empty list which causes error
        # ValueError: cannot compute fingerprint of empty list
        if num_ngram_requests := len(valid_ngram_requests):
            original_num_numba_threads = get_num_threads()
            # Ensure we use at least one thread.
            # If total tokens is small, using multiple threads
            # may slow down due to overhead.
            total_tokens = np.sum(num_tokens_no_spec)
            if total_tokens >= self.num_tokens_threshold:
                final_num_threads = max(
                    1, min(self.num_numba_thread_available, num_ngram_requests)
                )
                set_num_threads(final_num_threads)
            else:
                set_num_threads(1)
            # Results are written into the pre-allocated
            # self.valid_ngram_draft / self.valid_ngram_num_drafts buffers.
            batch_propose_numba(
                valid_ngram_requests,
                num_tokens_no_spec,
                token_ids_cpu,
                self.min_n,
                self.max_n,
                self.max_model_len,
                self.k,
                self.valid_ngram_draft,
                self.valid_ngram_num_drafts,
            )
            # Restore original number of threads.
            set_num_threads(original_num_numba_threads)
        for i in range(num_requests):
            if i in valid_ngram_requests and self.valid_ngram_num_drafts[i] > 0:
                draft_token_ids.append(
                    self.valid_ngram_draft[i, : self.valid_ngram_num_drafts[i]].tolist()
                )
            else:
                # No proposal for this request (skipped or no match found).
                draft_token_ids.append([])
        return draft_token_ids

    def propose(
        self,
        sampled_token_ids: list[list[int]],
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,  # unused
    ) -> list[list[int]]:
        """Filter requests eligible for ngram proposals, then batch-draft.

        Requests with no sampled tokens this step, or that already reached
        the max model length, get an empty proposal.
        """
        # find which requests need ngram proposals
        valid_ngram_requests = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                continue
            num_tokens = num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                continue
            valid_ngram_requests.append(i)
        draft_token_ids = self.batch_propose(
            len(sampled_token_ids),
            valid_ngram_requests,
            num_tokens_no_spec,
            token_ids_cpu,
        )
        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass
@njit(parallel=True)
def batch_propose_numba(
    valid_ngram_requests: list,
    num_tokens_no_spec: np.ndarray,
    token_ids_cpu: np.ndarray,
    min_n: int,
    max_n: int,
    max_model_len: int,
    k: int,
    valid_ngram_draft: np.ndarray,
    valid_ngram_num_drafts: np.ndarray,
):
    """Numba-parallel ngram proposal over the selected requests.

    Results are written into the caller's pre-allocated output arrays:
    valid_ngram_draft[idx, :n] holds the proposed tokens and
    valid_ngram_num_drafts[idx] their count, for each request index idx in
    valid_ngram_requests. Rows for other requests are left untouched.
    """
    # prange distributes requests over the numba thread pool configured by
    # the caller via set_num_threads().
    for i in prange(len(valid_ngram_requests)):
        idx = valid_ngram_requests[i]
        num_tokens = num_tokens_no_spec[idx]
        # Only the committed (non-speculative) context is searched.
        context_token_ids = token_ids_cpu[idx, :num_tokens]
        drafter_output = _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=context_token_ids,
            min_ngram=min_n,
            max_ngram=max_n,
            max_model_len=max_model_len,
            k=k,
        )
        valid_ngram_num_drafts[idx] = drafter_output.shape[0]
        # Empty output means no ngram of length >= min_n matched.
        if len(drafter_output):
            valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output
@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(
    origin_tokens: np.ndarray,
    min_ngram: int,
    max_ngram: int,
    max_model_len: int,
    k: int,
) -> np.ndarray:
    """
    Find the longest n-gram which matches the suffix of the given tokens
    whose length is within [min_ngram, max_ngram] (inclusive).
    If found, we will extract k right after the matched ngram.

    Implementation: the tokens are reversed and the KMP failure function
    (longest proper prefix that is also a suffix, "lps") is computed over
    the reversed sequence, turning the suffix match into a prefix match.
    Returns an empty array when no ngram of length >= min_ngram matches.
    """
    # Do not generate draft tokens if context is shorter than minimum n-gram
    total_token = origin_tokens.shape[0]
    if total_token < min_ngram:
        return np.empty((0,), dtype=origin_tokens.dtype)
    # Do not generate draft tokens beyond the max model length.
    k = min(k, max_model_len - total_token)
    if k <= 0:
        return np.empty((0,), dtype=origin_tokens.dtype)
    # Flip tokens, and the goal become to find longest ngram
    # on the rightmost position which matches the prefix with
    # length [min_n, max_n] (inclusive).
    tokens = origin_tokens[::-1]
    # Longest prefix (not including itself) which is a suffix of
    # the current position.
    # lps[i] = max{v, where tokens[0:v] == tokens[i+1-v:i+1]}
    #
    # As ngram is capped by max_ngram to save memory, we only need to
    # store lps for the first max_ngram prefix.
    lps = np.zeros(max_ngram, dtype=np.int32)
    longest_ngram = 0
    position = 0
    # lps[0] always equal to 0, we start with index 1
    prev_lps = 0
    i = 1
    while i < total_token:
        # tokens[:prev_lps] is the longest prefix as a suffix of tokens[:i]
        if tokens[prev_lps] == tokens[i]:
            # Token match: tokens[:prev_lps+1] is the longest prefix as
            # a suffix of tokens[:i+1]
            prev_lps += 1
            # Check if we found a longer valid ngram.
            #
            # Update position when longest_ngram matched prev_lps,
            # as we want to get the target n-gram of the earliest position
            # in the original tokens (i.e.
            # latest position in the reversed tokens)
            if prev_lps >= longest_ngram:
                longest_ngram = prev_lps
                position = i
            if i < max_ngram:
                # Store LPS for the first max_ngram prefix
                lps[i] = prev_lps
            if prev_lps == max_ngram:
                # When prev_lps reached max_ngram, update prev_lps
                # to lps[max_ngram-1] to avoid matching ngram
                # longer than max_ngram
                prev_lps = lps[max_ngram - 1]
            i += 1
        elif prev_lps != 0:
            # Token mismatch: try the second-longest prefix
            # among all suffix of tokens[:i],
            # which is the longest prefix of tokens[:prev_lps]
            prev_lps = lps[prev_lps - 1]
        else:
            # Token mismatch, and no more prefix (except empty string)
            # as a suffix of tokens[:i]
            i += 1
    if longest_ngram < min_ngram:
        # No valid ngram is found
        return np.empty((0,), dtype=origin_tokens.dtype)
    # Flip the position back, so in origin_tokens,
    # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram]
    # is the matched ngram, so we should start drafting tokens from
    # total_token-1-position+longest_ngram
    start_position = total_token - 1 - position + longest_ngram
    # Clip k so we never read past the end of the context.
    k = min(k, total_token - start_position)
    return origin_tokens[start_position : start_position + k]

View File

@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch
class SuffixDecodingProposer:
    """
    Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
    This class imports and uses the official implementation from Arctic Inference
    (https://github.com/snowflakedb/ArcticInference).
    """

    def __init__(self, vllm_config: VllmConfig):
        spec_config = vllm_config.speculative_config
        assert spec_config is not None, "Speculative config must be set"
        self.num_speculative_tokens = spec_config.num_speculative_tokens
        self.max_tree_depth = spec_config.suffix_decoding_max_tree_depth
        self.max_spec_factor = spec_config.suffix_decoding_max_spec_factor
        self.min_token_prob = spec_config.suffix_decoding_min_token_prob
        self.max_model_len = vllm_config.model_config.max_model_len
        # Lazy import to avoid error when Suffix Decoding is not used.
        from arctic_inference.suffix_decoding import SuffixDecodingCache

        # The cache tracks request outputs, evicts old requests, and keeps
        # one suffix tree per active prompt.
        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=spec_config.suffix_decoding_max_tree_depth,
            max_cached_requests=spec_config.suffix_decoding_max_cached_requests,
        )

    def propose(
        self,
        input_batch: InputBatch,
        sampled_token_ids: list[list[int]],
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,  # unused
    ) -> list[list[int]]:
        """
        Propose speculative tokens for every request in the input batch.
        Suffix Decoding speculates a variable number of tokens per request
        each decoding step, so the returned sub-lists may differ in length.
        """
        proposals: list[list[int]] = []
        for batch_idx, sampled_ids in enumerate(sampled_token_ids):
            if not sampled_ids:
                # Partial prefill: nothing sampled yet, skip speculation.
                proposals.append([])
                continue
            req_id = input_batch.req_ids[batch_idx]
            num_tokens = input_batch.num_tokens_no_spec[batch_idx]
            if num_tokens >= self.max_model_len:
                # Already at the model length limit; nothing to draft.
                proposals.append([])
                continue
            index = input_batch.req_id_to_index[req_id]
            if req_id not in self.suffix_cache.active_requests:
                if req_id in self.suffix_cache.cached_requests:
                    # Drop any stale cached response before restarting.
                    self.suffix_cache.evict_cached_response(req_id)
                prompt_len = input_batch.num_prompt_tokens[index]
                prompt_ids = input_batch.token_ids_cpu[index, :prompt_len]
                # Registers the request and builds its prompt suffix tree.
                self.suffix_cache.start_request(req_id, prompt_ids)
            # Feed the freshly sampled ids into this request's cache entry.
            self.suffix_cache.add_active_response(req_id, sampled_ids)
            # Only the trailing max_tree_depth tokens are usable as the
            # lookup pattern.
            pattern_start = max(0, num_tokens - self.max_tree_depth)
            pattern = input_batch.token_ids_cpu[batch_idx, pattern_start:num_tokens]
            spec_budget = min(
                self.num_speculative_tokens, self.max_model_len - num_tokens - 1
            )
            draft = self.suffix_cache.speculate(
                req_id,
                pattern,
                max_spec_tokens=spec_budget,
                max_spec_factor=self.max_spec_factor,
                min_token_prob=self.min_token_prob,
            )
            proposals.append(draft.token_ids)
        # Retire requests that are no longer present in the input batch.
        stale_req_ids = (
            self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
        )
        for req_id in stale_req_ids:
            self.suffix_cache.stop_request(req_id)
        return proposals

    def load_model(self, *args, **kwargs):
        # Suffix decoding is model-free; nothing to load.
        pass

View File

@@ -0,0 +1,357 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig, replace
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
)
PADDING_SLOT_ID = -1
@triton.jit
def eagle_prepare_inputs_padded_kernel(
    cu_num_draft_tokens_ptr,  # [num_reqs]
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_gpu_ptr,  # [num_reqs + 1]
    token_indices_to_sample_ptr,  # [num_reqs] (output)
    num_rejected_tokens_gpu_ptr,  # [num_reqs] (output)
    num_reqs,  # tl.int32
):
    """
    Fused kernel for Eagle prepare_input_padded. This kernel computes the
    token index to sample for each request, taking into account the number
    of draft tokens and the number of valid sampled tokens (which is one more than
    the number of accepted tokens).

    Launch: one program per request along grid axis 0.
    """
    req_idx = tl.program_id(axis=0)
    # Guard against a grid launched larger than the batch.
    if req_idx >= num_reqs:
        return
    # Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive
    # cumulative sum (first entry is the first value, not zero).
    cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = 0
    if req_idx == 0:
        num_draft_tokens = cu_draft_curr
    else:
        cu_draft_prev = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
        num_draft_tokens = cu_draft_curr - cu_draft_prev
    valid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx)
    # valid_count = accepted + 1 (bonus token), so
    # rejected = drafted + 1 - valid.
    num_rejected_tokens = num_draft_tokens + 1 - valid_count
    # A request with no draft tokens cannot have rejections.
    num_rejected_tokens = tl.where(num_draft_tokens > 0, num_rejected_tokens, 0)
    # query_start_loc[req_idx + 1] is the start position of the next request,
    # which is one past the last token of this request.
    q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + req_idx + 1) - 1
    # Step back over the rejected tail to the last accepted position.
    index_to_sample = q_last_tok_idx - num_rejected_tokens
    tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
    tl.store(num_rejected_tokens_gpu_ptr + req_idx, num_rejected_tokens)
@triton.jit
def eagle_prepare_next_token_padded_kernel(
    sampled_token_ids_ptr,  # [num_reqs, num_sampled_tokens_per_req]
    discard_request_mask_ptr,  # [num_reqs]
    backup_next_token_ids_ptr,  # [num_reqs]
    next_token_ids_ptr,  # [num_reqs] (output)
    valid_sampled_tokens_count_ptr,  # [num_reqs] (output)
    vocab_size,  # tl.int32
    num_sampled_tokens_per_req,  # tl.int32 (num_spec_tokens + 1)
    num_reqs,  # tl.int32
    stride_sampled_token_ids,  # tl.int32 (stride for dim 0)
    BLOCK_SIZE_TOKENS: tl.constexpr,  # Power-of-2 >= num_sampled_tokens_per_req
):
    """
    Fused kernel for Eagle prepare_next_token_ids_padded. This kernel computes the
    number of valid (1 + accepted) tokens for each request, and the corresponding
    "next" token id to sample from during speculative decoding. This is the
    "last accepted token" from the sampled tokens, or the backup token if no
    tokens were accepted or if the request is marked as discarded.

    Launch: one program per request along grid axis 0; each program scans
    its request's row of sampled tokens in a single BLOCK_SIZE_TOKENS tile.
    """
    req_idx = tl.program_id(axis=0)
    # Guard against a grid launched larger than the batch.
    if req_idx >= num_reqs:
        return
    # Check if this request is discarded.
    is_discarded = tl.load(discard_request_mask_ptr + req_idx)
    if is_discarded:
        # Discarded requests contribute zero valid tokens and fall back to
        # the precomputed backup token.
        backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
        valid_count = tl.full((), 0, dtype=tl.uint32)
        tl.store(next_token_ids_ptr + req_idx, backup_token)
        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
    else:
        # Count the number of valid tokens among the sampled tokens.
        token_offs = tl.arange(0, BLOCK_SIZE_TOKENS)
        token_mask = token_offs < num_sampled_tokens_per_req
        row_ptr = sampled_token_ids_ptr + req_idx * stride_sampled_token_ids
        token_ids = tl.load(row_ptr + token_offs, mask=token_mask, other=-1)
        # Rejected tokens are -1, valid tokens are in [0, vocab_size)
        is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask
        valid_count = tl.sum(is_valid_mask)
        if valid_count > 0:
            # Guaranteed to be well-defined since
            # valid_count > 0 implies is_valid_mask is not empty
            last_valid_index = tl.max(tl.where(is_valid_mask, token_offs, -1))
            # Select the token at that index, using a sum trick since
            # we don't want to load again to access token_ids[last_valid_index].
            last_valid_token = tl.sum(
                tl.where(token_offs == last_valid_index, token_ids, 0)
            )
            tl.store(next_token_ids_ptr + req_idx, last_valid_token)
        else:
            # No valid tokens found, use backup token
            backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
            tl.store(next_token_ids_ptr + req_idx, backup_token)
        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
def compute_new_slot_mapping(
    cad: CommonAttentionMetadata,
    new_positions: torch.Tensor,
    is_rejected_token_mask: torch.Tensor,
    block_size: int,
    num_new_tokens: int,
    max_model_len: int,
):
    """Map each (request, position) pair to its flat KV-cache slot.

    Rejected tokens and positions at or beyond max_model_len are mapped to
    PADDING_SLOT_ID so they are never written to the KV cache.
    """
    num_reqs, blocks_per_req = cad.block_table_tensor.shape
    # Expand request indices so every token knows which request it belongs to.
    per_req_lens = cad.naive_query_lens() + num_new_tokens
    request_of_token = torch.repeat_interleave(
        torch.arange(num_reqs, device=cad.query_start_loc.device),
        per_req_lens,
        output_size=len(new_positions),
    )
    # Clamp positions first so the block-table gather below cannot index
    # out of bounds; over-length positions are masked afterwards.
    safe_positions = new_positions.clamp(max=max_model_len - 1)
    flat_block_idx = request_of_token * blocks_per_req + safe_positions // block_size
    block_numbers = cad.block_table_tensor.view(-1)[flat_block_idx]
    slots = block_numbers * block_size + safe_positions % block_size
    # Positions past the model length and rejected tokens must not be
    # written to the KV cache.
    exceeds_max_model_len = new_positions >= max_model_len
    slots.masked_fill_(exceeds_max_model_len | is_rejected_token_mask, PADDING_SLOT_ID)
    return slots
def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """The vllm_config is configured for the target model, e.g.
    its quant_config and parallel_config. But the draft model is potentially
    quantized differently, and has potentially different tensor_parallel_size.
    This function creates a new vllm_config configured for the drafter.
    The vllm_config is useful when loading the draft model with get_model().
    """
    base = target_model_vllm_config
    assert base.speculative_config is not None, "speculative_config is not set"
    spec = base.speculative_config
    # Carry the current rank over into the draft parallel config.
    draft_parallel_config = replace(
        spec.draft_parallel_config, rank=base.parallel_config.rank
    )
    # Swap in the draft model/parallel configs; quantization is taken from
    # the draft model config, so the target's quant_config is cleared.
    draft_vllm_config: VllmConfig = replace(
        base,
        quant_config=None,
        parallel_config=draft_parallel_config,
        model_config=spec.draft_model_config,
    )
    return draft_vllm_config
def extend_all_queries_by_N(
    common_attn_metadata: CommonAttentionMetadata,
    N: int,
    arange: torch.Tensor,
    new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
    """Return a copy of the attention metadata with every query grown by N.

    Every per-request query length and sequence length is increased by N
    tokens. Useful e.g. for speculative decoding with parallel drafting,
    where each sequence is extended by N tokens and all of them are
    predicted in one forward pass. The slot mapping must be computed by the
    caller, since it needs extra information (block tables, positions).

    Args:
        common_attn_metadata: Metadata to extend (not mutated).
        N: Number of tokens appended to every request.
        arange: Device-resident arange tensor, at least as long as
            query_start_loc; used to build the cumulative offsets.
        new_slot_mapping: Precomputed slot mapping for the extended batch.
    """
    meta = common_attn_metadata
    # Cumulative offsets [+0, +N, +2N, ...] shift each request's start by
    # the extra tokens of all preceding requests.
    device_offsets = N * arange[: len(meta.query_start_loc)]
    cpu_offsets = N * torch.arange(
        len(meta.query_start_loc_cpu), dtype=torch.int32
    )
    # Every request gains N tokens, so batch_size * N tokens are added.
    added_tokens = meta.batch_size() * N
    return meta.replace(
        query_start_loc=meta.query_start_loc + device_offsets,
        query_start_loc_cpu=meta.query_start_loc_cpu + cpu_offsets,
        seq_lens=meta.seq_lens + N,
        num_actual_tokens=meta.num_actual_tokens + added_tokens,
        # All query lens grow by N, so the maxima grow by N as well.
        max_query_len=meta.max_query_len + N,
        max_seq_len=meta.max_seq_len + N,
        slot_mapping=new_slot_mapping,
    )
# Unified copy/expand kernel: one launch copies target-model tokens into the
# drafting buffers and appends bonus / parallel-drafting / rejected slots.
@triton.jit
def copy_and_expand_eagle_inputs_kernel(
    # (Padded) Inputs from the target model
    target_token_ids_ptr,  # [total_tokens_in_batch]
    target_positions_ptr,  # [total_tokens_in_batch]
    next_token_ids_ptr,  # [num_reqs]
    # Outputs to the drafting buffers
    out_input_ids_ptr,  # [total_draft_tokens_in_batch] (output)
    out_positions_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_rejected_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_masked_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_new_token_indices_ptr,  # [num_padding_slots_per_request * num_reqs] (output)
    out_hidden_state_mapping_ptr,  # [total_tokens_in_batch]
    # Input metadata
    query_start_loc_ptr,  # [num_reqs + 1], last value is the total num input tokens
    query_end_loc_ptr,  # [num_reqs]
    padding_token_id,  # tl.int32
    parallel_drafting_token_id,  # tl.int32
    # Sizing info
    total_input_tokens,  # tl.int32
    num_padding_slots_per_request,  # tl.int32
    shift_input_ids,  # tl.bool
    BLOCK_SIZE_TOKENS: tl.constexpr,  # Blocks along token dim to handle prefills
):
    """
    Copy and expand inputs from the target model to the drafting buffers for Eagle
    speculative decoding. This kernel handles padding slots and parallel drafting
    tokens, if enabled.

    Launch grid: axis 0 is the request index, axis 1 tiles each request's
    output tokens in chunks of BLOCK_SIZE_TOKENS (so long prefills are
    covered by multiple programs).

    Per-request output layout (indices j relative to the request's output
    start):
      [0, num_valid_tokens)                  tokens copied from the input
      [num_valid_tokens]                     bonus token from next_token_ids
      (num_valid_tokens,
       num_valid_tokens + num_padding_slots) parallel-drafting placeholder slots
      [..., total_output_tokens)             rejected-token slots (padding)

    With shift_input_ids=True the first input token of each request is
    skipped (Eagle-style shift), so each request's output region is one
    token shorter and the copy reads from input_offset = 1.
    """
    request_idx = tl.program_id(axis=0)
    token_batch_idx = tl.program_id(axis=1)

    # Load query locations for this request.
    query_start_loc = tl.load(query_start_loc_ptr + request_idx)
    next_query_start_loc = tl.load(query_start_loc_ptr + request_idx + 1)
    query_end_loc = tl.load(query_end_loc_ptr + request_idx)

    # Calculate number of valid tokens to copy and input offset
    # With shift_input_ids=True, we skip the first token
    # Output layout: each request gets (input_len + num_padding_slots_per_request) slots
    # But with shift, we lose one token per request
    if shift_input_ids:
        num_valid_tokens = query_end_loc - query_start_loc
        input_offset = 1
        # Each preceding request contributed (num_padding_slots_per_request - 1)
        # extra output tokens (padding slots minus the dropped first token).
        output_start = query_start_loc + request_idx * (
            num_padding_slots_per_request - 1
        )
    else:
        num_valid_tokens = query_end_loc - query_start_loc + 1
        input_offset = 0
        output_start = query_start_loc + request_idx * num_padding_slots_per_request

    # Number of rejected tokens from previous speculation
    num_rejected = next_query_start_loc - query_end_loc - 1

    # Total output tokens for this request
    total_output_tokens = (
        num_valid_tokens + num_padding_slots_per_request + num_rejected
    )

    # Process tokens in this block: j is the request-local output index.
    j = token_batch_idx * BLOCK_SIZE_TOKENS + tl.arange(0, BLOCK_SIZE_TOKENS)

    # Compute masks for different output regions:
    # [0, num_valid_tokens): valid tokens copied from input
    # [num_valid_tokens]: bonus token from next_token_ids
    # (num_valid_tokens, num_valid_tokens + num_padding_slots_per_request):
    #   parallel drafting slots
    # [num_valid_tokens + num_padding_slots_per_request, total_output_tokens):
    #   rejected slots
    in_bounds = j < total_output_tokens
    is_valid_region = j < num_valid_tokens
    is_bonus_region = j == num_valid_tokens
    is_parallel_draft_region = (j > num_valid_tokens) & (
        j < num_valid_tokens + num_padding_slots_per_request
    )
    is_rejected_region = j >= num_valid_tokens + num_padding_slots_per_request

    # Compute output indices
    out_idx = output_start + j

    # For valid tokens, compute input index
    in_idx = query_start_loc + input_offset + j
    # Clamp to avoid out-of-bounds access (masked loads still need valid addresses)
    in_idx_clamped = tl.minimum(in_idx, total_input_tokens - 1)

    # Load input tokens (masked to valid region)
    token_ids = tl.load(
        target_token_ids_ptr + in_idx_clamped, mask=is_valid_region & in_bounds, other=0
    )

    # Load the starting position for this request (first position in the sequence)
    start_pos = tl.load(target_positions_ptr + query_start_loc)

    # Load bonus token for this request
    bonus_token = tl.load(next_token_ids_ptr + request_idx)

    # Build final token_ids based on region (later writes override earlier ones).
    token_ids = tl.where(is_bonus_region, bonus_token, token_ids)
    token_ids = tl.where(
        is_parallel_draft_region, parallel_drafting_token_id, token_ids
    )
    token_ids = tl.where(is_rejected_region, padding_token_id, token_ids)

    # Build final positions:
    # Positions are NOT shifted - they start from the first input position and increment
    # Output position j gets start_pos + j
    # (e.g., input positions [5,6,7] -> output [5,6,7,8,9,...])
    positions = start_pos + j
    # Rejected positions are don't-care, set to 0
    positions = tl.where(is_rejected_region, 0, positions)

    # Compute output masks
    is_rejected_out = is_rejected_region & in_bounds
    is_masked_out = is_parallel_draft_region & in_bounds

    # Compute indices of new tokens (bonus + parallel drafting) for sampling
    # New tokens are at positions
    # [num_valid_tokens, num_valid_tokens + num_padding_slots_per_request)
    is_new_token_region = (j >= num_valid_tokens) & (
        j < num_valid_tokens + num_padding_slots_per_request
    )
    new_token_local_idx = (
        j - num_valid_tokens
    )  # 0 for bonus, 1, 2, ... for parallel drafting
    new_token_out_idx = (
        request_idx * num_padding_slots_per_request + new_token_local_idx
    )

    # Compute hidden state mapping (source index -> destination index)
    # This maps each input position to its corresponding output position
    # Hidden states don't get shifted, so we map all input tokens (including rejected)
    if shift_input_ids:
        num_input_tokens_this_request = next_query_start_loc - query_start_loc
        is_input_region = j < num_input_tokens_this_request
        src_idx = query_start_loc + j
        tl.store(out_hidden_state_mapping_ptr + src_idx, out_idx, mask=is_input_region)

    # Store outputs (all masked to this request's output region).
    tl.store(out_input_ids_ptr + out_idx, token_ids, mask=in_bounds)
    tl.store(out_positions_ptr + out_idx, positions, mask=in_bounds)
    tl.store(out_is_rejected_token_mask_ptr + out_idx, is_rejected_out, mask=in_bounds)
    tl.store(out_is_masked_token_mask_ptr + out_idx, is_masked_out, mask=in_bounds)
    tl.store(
        out_new_token_indices_ptr + new_token_out_idx,
        out_idx,
        mask=is_new_token_region & in_bounds,
    )