# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import numpy as np
from numba import get_num_threads, jit, njit, prange, set_num_threads

from vllm.config import VllmConfig


class NgramProposer:
    def __init__(self, vllm_config: VllmConfig):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        # Minimum length of the n-gram to match.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        # Maximum length of the n-gram to match.
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        # Number of tokens to propose after the match. If fewer than k
        # tokens follow the match, we return as many tokens as remain
        # up to the end of the context.
        self.k = vllm_config.speculative_config.num_speculative_tokens
        # Maximum length of the model.
        self.max_model_len = vllm_config.model_config.max_model_len

        # Pre-allocate buffers for numba batch propose.
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32)
        self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32)

        # Threshold on the total number of tokens in the batch to enable
        # multi-threading in numba batch propose.
        self.num_tokens_threshold = 8192
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        cpu_count = os.cpu_count()
        # Max number of threads for numba parallel processing.
        if cpu_count:
            # Divide by 2 to use physical cores
            # and not logical cores (hyper-threading).
            # Cap the number of threads to avoid using too many,
            # since other components like the frontend (incl. tokenization)
            # and Structured Outputs also use multiple threads.
            # TODO(ekagra-ranjan): bump up the cap from 1 to 8
            # when TP parallelization for ngram is implemented.
            self.num_numba_thread_available = min(1, (cpu_count // 2))
            # Divide by tp_size to ensure each tensor parallel rank
            # has some threads since all ranks will run this.
            self.num_numba_thread_available //= tp_size
        else:
            self.num_numba_thread_available = 1

        # Trigger Numba JIT compilation for the N-gram proposer.
        # This usually takes less than 1 second.
        self.propose(
            [[]] * 1024,
            [""] * 1024,
            np.zeros(1024, dtype=np.int32),
            np.zeros((1024, self.max_model_len), dtype=np.int32),
            set(),
        )

    def batch_propose(
        self,
        num_requests: int,
        valid_ngram_requests: list,
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
    ) -> list[list[int]]:
        """Batch version of the ngram proposer using numba for acceleration.

        Args:
            num_requests:
                Total number of requests in the batch.
            valid_ngram_requests:
                Indices of requests that need ngram proposals.
            num_tokens_no_spec:
                Numpy array of shape (batch_size,) representing the number
                of tokens without speculative tokens for each request.
            token_ids_cpu:
                Numpy array of shape (batch_size, max_model_len)
                representing the token IDs for each request.

        Returns:
            list[list[int]]: A list where each element is the list of
                proposed token IDs for the corresponding request.
        """
        draft_token_ids: list[list[int]] = []

        # Only run batch propose if there are requests needing ngram proposals.
        # Avoid calling the numba function with an empty list, which raises
        # "ValueError: cannot compute fingerprint of empty list".
        if num_ngram_requests := len(valid_ngram_requests):
            original_num_numba_threads = get_num_threads()
            # Ensure we use at least one thread. If the total token count is
            # small, using multiple threads may slow things down due to the
            # threading overhead.
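            # Illustrative numbers (assumed for this comment, not taken from
            # any real config): with num_tokens_threshold=8192 and
            # num_numba_thread_available=4, 16 pending ngram requests
            # totalling ~20k tokens would run with max(1, min(4, 16)) == 4
            # numba threads below, while a 2k-token batch stays on a
            # single thread.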
            total_tokens = np.sum(num_tokens_no_spec)
            if total_tokens >= self.num_tokens_threshold:
                final_num_threads = max(
                    1, min(self.num_numba_thread_available, num_ngram_requests)
                )
                set_num_threads(final_num_threads)
            else:
                set_num_threads(1)

            batch_propose_numba(
                valid_ngram_requests,
                num_tokens_no_spec,
                token_ids_cpu,
                self.min_n,
                self.max_n,
                self.max_model_len,
                self.k,
                self.valid_ngram_draft,
                self.valid_ngram_num_drafts,
            )

            # Restore the original number of threads.
            set_num_threads(original_num_numba_threads)

        for i in range(num_requests):
            if i in valid_ngram_requests and self.valid_ngram_num_drafts[i] > 0:
                draft_token_ids.append(
                    self.valid_ngram_draft[i, : self.valid_ngram_num_drafts[i]].tolist()
                )
            else:
                draft_token_ids.append([])

        return draft_token_ids

    def propose(
        self,
        sampled_token_ids: list[list[int]],
        req_ids: list[str],
        num_tokens_no_spec: np.ndarray,
        token_ids_cpu: np.ndarray,
        spec_decode_unsupported_reqs: set,
    ) -> list[list[int]]:
        # Find which requests need ngram proposals.
        valid_ngram_requests = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = req_ids[i]
            if req_id in spec_decode_unsupported_reqs:
                continue

            num_tokens = num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                continue

            valid_ngram_requests.append(i)

        draft_token_ids = self.batch_propose(
            len(sampled_token_ids),
            valid_ngram_requests,
            num_tokens_no_spec,
            token_ids_cpu,
        )

        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass


@njit(parallel=True)
def batch_propose_numba(
    valid_ngram_requests: list,
    num_tokens_no_spec: np.ndarray,
    token_ids_cpu: np.ndarray,
    min_n: int,
    max_n: int,
    max_model_len: int,
    k: int,
    valid_ngram_draft: np.ndarray,
    valid_ngram_num_drafts: np.ndarray,
):
    for i in prange(len(valid_ngram_requests)):
        idx = valid_ngram_requests[i]
        num_tokens = num_tokens_no_spec[idx]
        context_token_ids = token_ids_cpu[idx, :num_tokens]
        drafter_output = _find_longest_matched_ngram_and_propose_tokens(
            origin_tokens=context_token_ids,
            min_ngram=min_n,
            max_ngram=max_n,
            max_model_len=max_model_len,
            k=k,
        )
        valid_ngram_num_drafts[idx] = drafter_output.shape[0]
        if len(drafter_output):
            valid_ngram_draft[idx, : drafter_output.shape[0]] = drafter_output


@jit(nopython=True)
def _find_longest_matched_ngram_and_propose_tokens(
    origin_tokens: np.ndarray,
    min_ngram: int,
    max_ngram: int,
    max_model_len: int,
    k: int,
) -> np.ndarray:
    """
    Find the longest n-gram whose length is within [min_ngram, max_ngram]
    (inclusive) and which matches the suffix of the given tokens.

    If found, extract up to k tokens right after the earlier occurrence of
    the matched n-gram.
    """
    # Do not generate draft tokens if the context is shorter than the
    # minimum n-gram length.
    total_token = origin_tokens.shape[0]
    if total_token < min_ngram:
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Do not generate draft tokens beyond the max model length.
    k = min(k, max_model_len - total_token)
    if k <= 0:
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Flip the tokens, so the goal becomes finding the longest n-gram at the
    # rightmost position which matches a prefix of length within
    # [min_ngram, max_ngram] (inclusive).
    tokens = origin_tokens[::-1]

    # Longest proper prefix (not including the whole string) which is also a
    # suffix ending at the current position:
    # lps[i] = max{v, where tokens[0:v] == tokens[i+1-v:i+1]}
    #
    # As the n-gram is capped at max_ngram to save memory, we only need to
    # store lps for the first max_ngram prefixes.
    lps = np.zeros(max_ngram, dtype=np.int32)

    longest_ngram = 0
    position = 0

    # lps[0] is always 0, so we start from index 1.
    prev_lps = 0
    i = 1
    while i < total_token:
        # tokens[:prev_lps] is the longest prefix that is a suffix of tokens[:i].
        if tokens[prev_lps] == tokens[i]:
            # Token match: tokens[:prev_lps+1] is the longest prefix that is
            # a suffix of tokens[:i+1].
            prev_lps += 1
            # Check if we found a longer valid ngram.
            #
            # Update position also when prev_lps only ties longest_ngram,
            # because we want the target n-gram at the earliest position
            # in the original tokens (i.e. the latest position in the
            # reversed tokens).
            if prev_lps >= longest_ngram:
                longest_ngram = prev_lps
                position = i
            if i < max_ngram:
                # Store LPS for the first max_ngram prefixes.
                lps[i] = prev_lps
            if prev_lps == max_ngram:
                # When prev_lps reaches max_ngram, reset prev_lps to
                # lps[max_ngram-1] to avoid matching an n-gram longer
                # than max_ngram.
                prev_lps = lps[max_ngram - 1]
            i += 1
        elif prev_lps != 0:
            # Token mismatch: try the second-longest prefix among all
            # suffixes of tokens[:i], which is the longest prefix that is
            # also a suffix of tokens[:prev_lps].
            prev_lps = lps[prev_lps - 1]
        else:
            # Token mismatch, and no prefix (except the empty string)
            # is a suffix of tokens[:i].
            i += 1

    if longest_ngram < min_ngram:
        # No valid n-gram was found.
        return np.empty((0,), dtype=origin_tokens.dtype)

    # Flip the position back: in origin_tokens,
    # origin_tokens[total_token-1-position : total_token-1-position+longest_ngram]
    # is the matched n-gram, so we should start drafting tokens from
    # total_token-1-position+longest_ngram.
    start_position = total_token - 1 - position + longest_ngram
    k = min(k, total_token - start_position)
    return origin_tokens[start_position : start_position + k]
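

# Minimal illustrative sketch: run the suffix-matching helper on a tiny,
# hand-picked token sequence. The token values and the min_ngram / max_ngram /
# max_model_len / k settings are assumptions chosen for the example, not
# values taken from any vLLM config.
if __name__ == "__main__":
    context = np.array([1, 2, 3, 4, 2, 3], dtype=np.int32)
    # The 2-gram suffix (2, 3) also occurs at indices 1-2, so the k=2 tokens
    # that followed that earlier occurrence, [4, 2], become the draft.
    draft = _find_longest_matched_ngram_and_propose_tokens(
        origin_tokens=context,
        min_ngram=2,
        max_ngram=3,
        max_model_len=32,
        k=2,
    )
    print(draft.tolist())  # -> [4, 2]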