104 lines
4.4 KiB
Python
104 lines
4.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import numpy as np
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
|
|
|
|
|
class SuffixDecodingProposer:
|
|
"""
|
|
Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
|
|
This class imports and uses the official implementation from Arctic Inference
|
|
(https://github.com/snowflakedb/ArcticInference).
|
|
"""
|
|
|
|
def __init__(self, vllm_config: VllmConfig):
|
|
config = vllm_config.speculative_config
|
|
self.num_speculative_tokens = config.num_speculative_tokens
|
|
self.max_tree_depth = config.suffix_decoding_max_tree_depth
|
|
self.max_spec_factor = config.suffix_decoding_max_spec_factor
|
|
self.min_token_prob = config.suffix_decoding_min_token_prob
|
|
self.max_model_len = vllm_config.model_config.max_model_len
|
|
|
|
# Lazy import to avoid error when Suffix Decoding is not used.
|
|
from arctic_inference.suffix_decoding import SuffixDecodingCache
|
|
|
|
# Initialize and empty cache. This object will take care of caching request
|
|
# outputs, evicting old requests, and manages the per-prompt suffix trees.
|
|
self.suffix_cache = SuffixDecodingCache(
|
|
max_tree_depth=config.suffix_decoding_max_tree_depth,
|
|
max_cached_requests=config.suffix_decoding_max_cached_requests,
|
|
)
|
|
|
|
def propose(
|
|
self,
|
|
input_batch: InputBatch,
|
|
sampled_token_ids: list[np.ndarray],
|
|
) -> list[list[int]]:
|
|
"""
|
|
Propose speculative tokens for each request in the input batch. Suffix Decoding
|
|
will speculate a dynamic number of tokens for each request every decoding step,
|
|
so each entry in the returned list may have different lengths.
|
|
"""
|
|
draft_token_ids: list[np.ndarray] = []
|
|
for i, sampled_ids in enumerate(sampled_token_ids):
|
|
if sampled_ids.shape[0] == 0:
|
|
# Skip speculative decoding for partial prefills.
|
|
draft_token_ids.append([])
|
|
continue
|
|
|
|
# Skip requests that require sampling parameters that are not
|
|
# supported with speculative decoding.
|
|
req_id = input_batch.req_ids[i]
|
|
if req_id in input_batch.spec_decode_unsupported_reqs:
|
|
draft_token_ids.append([])
|
|
continue
|
|
|
|
num_tokens = input_batch.num_tokens_no_spec[i]
|
|
if num_tokens >= self.max_model_len:
|
|
# Skip requests that have already reached the max model length.
|
|
draft_token_ids.append([])
|
|
continue
|
|
|
|
index = input_batch.req_id_to_index[req_id]
|
|
if req_id not in self.suffix_cache.active_requests:
|
|
if req_id in self.suffix_cache.cached_requests:
|
|
# Reset the suffix cache for this request.
|
|
self.suffix_cache.evict_cached_response(req_id)
|
|
num_prompt_tokens = input_batch.num_prompt_tokens[index]
|
|
prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
|
|
# Start a new request, this will build the suffix tree for that prompt.
|
|
self.suffix_cache.start_request(req_id, prompt_token_ids)
|
|
|
|
# Append the newly sampled ids to the suffix cache for this request.
|
|
self.suffix_cache.add_active_response(req_id, sampled_ids.tolist())
|
|
|
|
# Suffix decoding only uses the most recent tokens up to max_tree_depth, so
|
|
# we extract the pattern from the end of the input.
|
|
start = max(0, num_tokens - self.max_tree_depth)
|
|
pattern = input_batch.token_ids_cpu[i, start:num_tokens]
|
|
draft = self.suffix_cache.speculate(
|
|
req_id,
|
|
pattern,
|
|
max_spec_tokens=min(
|
|
self.num_speculative_tokens, self.max_model_len - num_tokens - 1
|
|
),
|
|
max_spec_factor=self.max_spec_factor,
|
|
min_token_prob=self.min_token_prob,
|
|
)
|
|
|
|
draft_token_ids.append(draft.token_ids)
|
|
|
|
# Stop requests that were not seen in the input batch.
|
|
for req_id in (
|
|
self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
|
|
):
|
|
self.suffix_cache.stop_request(req_id)
|
|
|
|
return draft_token_ids
|
|
|
|
def load_model(self, *args, **kwargs):
|
|
# No model to load.
|
|
pass
|