Files
sglang/python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp
Zhihao Zhang e7bc600304 [Feature] Speculative decoding support lookahead (#9873)
Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
2025-09-18 16:42:41 -07:00

373 lines
13 KiB
C++

#include "lookahead.h"
#include <limits>
#include <vector>
namespace lookahead {
// One node of the compact draft tree that matchBFS/matchProb build and
// fillResult flattens.
struct Node {
// Maps a child's token id to the index of that child's Node in the tree vector.
std::unordered_map<int32_t, int32_t> next;
};
// Flattens the draft tree into a breadth-first token list and builds the
// matching ancestor mask.
//
// last_token:      token placed at position 0 of the result.
// draft_token_num: minimum number of output positions; missing ones are padded
//                  with token 0 whose parent is position 0. Non-positive values
//                  request no padding.
// tree:            draft tree; tree[i].next maps a token to the child's index.
// root:            index of the root node inside tree.
//
// Returns a Result holding n tokens and a row-major n x n mask in which row i
// marks position i together with every position on its parent chain.
Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
  const size_t target_len = draft_token_num > 0 ? static_cast<size_t>(draft_token_num) : 0;

  Lookahead::Result info;
  std::vector<int32_t> prevs;  // prevs[i]: parent position of i, -1 for position 0
  info.token.reserve(target_len);
  prevs.reserve(target_len);

  info.token.emplace_back(last_token);
  prevs.emplace_back(-1);

  // Breadth-first walk; queue entries are (token, node index in tree, parent position).
  std::queue<std::tuple<int32_t, int32_t, int32_t>> queue;
  for (const auto& [token, next] : tree[root].next) {
    queue.emplace(token, next, 0);
  }
  while (!queue.empty()) {
    const auto [token, next, prev] = queue.front();
    queue.pop();
    info.token.emplace_back(token);
    prevs.emplace_back(prev);
    const int32_t self = static_cast<int32_t>(info.token.size() - 1);
    for (const auto& [t, n] : tree[next].next) {
      queue.emplace(t, n, self);
    }
  }

  // zero padding to length
  if (info.token.size() < target_len) {
    info.token.resize(target_len, 0);
    prevs.resize(target_len, 0);
  }

  const size_t n = info.token.size();
  info.mask.resize(n * n, 0);
  for (size_t i = 0; i < n; ++i) {
    if (prevs[i] != -1) {
      // A parent always precedes its children in BFS order, so its row is
      // already complete. Inherit columns [0, parent]; the length is scaled by
      // the element size so the copy stays correct for any mask element type.
      const size_t parent = static_cast<size_t>(prevs[i]);
      memcpy(&info.mask[i * n], &info.mask[parent * n], sizeof(info.mask[0]) * (parent + 1));
    }
    info.mask[i * n + i] = 1;  // every position sees itself (covers mask[0] too)
  }
  return info;
}
// Sets up the trie node pool, validates the configuration and starts the
// background insert worker.
//
// capacity: number of trie nodes to pre-allocate.
// param:    configuration, copied into param_.
//
// Throws std::runtime_error when any configuration value fails the checks
// below. The worker thread is only started after every check has passed.
Lookahead::Lookahead(size_t capacity, const Param& param) {
  param_ = param;

  // Pre-allocate all nodes and register each of them in the free pool.
  nodes_.resize(capacity);
  for (auto& slot : nodes_) {
    node_pool_.emplace_back(&slot);
  }
  free_node_count_ = node_pool_.size();
  root_ = getNode();

  // All validation failures are reported the same way.
  const auto fail = [](const std::string& message) { throw std::runtime_error(message); };

  if (!(param_.branch_length > 1)) {
    fail("param_.branch_length must be greater than 1, current value: " + std::to_string(param_.branch_length));
  }

  // Match window bounds: 0 < min <= max < branch_length.
  if (!(param_.min_match_window_size > 0)) {
    fail("min_match_window_size must be greater than 0, current value: " +
         std::to_string(param_.min_match_window_size));
  }
  if (!(param_.min_match_window_size <= param_.max_match_window_size)) {
    fail("min_match_window_size must be less than or equal to max_match_window_size, current min_match_window_size: " +
         std::to_string(param_.min_match_window_size) +
         ", max_match_window_size: " + std::to_string(param_.max_match_window_size));
  }
  if (!(param_.max_match_window_size < param_.branch_length)) {
    fail("max_match_window_size must be less than branch_length, current max_match_window_size: " +
         std::to_string(param_.max_match_window_size) + ", branch_length: " + std::to_string(param_.branch_length));
  }

  // BFS breadth bounds: 0 < min <= max.
  if (!(param_.min_bfs_breadth > 0)) {
    fail("min_bfs_breadth must be greater than 0, current value: " + std::to_string(param_.min_bfs_breadth));
  }
  if (!(param_.min_bfs_breadth <= param_.max_bfs_breadth)) {
    fail("min_bfs_breadth must be less than or equal to max_bfs_breadth, current min_bfs_breadth: " +
         std::to_string(param_.min_bfs_breadth) + ", max_bfs_breadth: " + std::to_string(param_.max_bfs_breadth));
  }

  if (!(param_.draft_token_num > 0)) {
    fail("draft_token_num must be greater than 0, current value: " + std::to_string(param_.draft_token_num));
  }

  // Per-batch overrides: entries equal to the type's maximum value are exempt
  // from the range checks.
  for (auto limit : param_.batch_draft_token_num) {
    if (limit == std::numeric_limits<decltype(limit)>::max()) {
      continue;
    }
    if (!(limit <= param_.draft_token_num)) {
      fail("batch_draft_token_num config value " + std::to_string(limit) +
           " must be less than or equal to draft_token_num: " + std::to_string(param_.draft_token_num));
    }
  }
  for (auto window : param_.batch_min_match_window_size) {
    if (window == std::numeric_limits<decltype(window)>::max()) {
      continue;
    }
    if (!(window >= param_.min_match_window_size)) {
      fail("batch_min_match_window_size config value " + std::to_string(window) +
           " must be greater than or equal to min_match_window_size: " + std::to_string(param_.min_match_window_size));
    }
    if (!(window <= param_.max_match_window_size)) {
      fail("batch_min_match_window_size config value " + std::to_string(window) +
           " must be less than or equal to max_match_window_size: " + std::to_string(param_.max_match_window_size));
    }
  }

  quit_flag_ = false;
  insert_worker_ = std::thread(&Lookahead::insert, this);
}
// Stops the background insert worker and waits for it to exit.
Lookahead::~Lookahead() {
// Raised first so that the worker's `while (!quit_flag_)` loop in insert()
// ends on its next check.
quit_flag_ = true;
// NOTE(review): presumably close() wakes a worker that is blocked inside
// dequeue() — confirm against the queue implementation.
insert_queue_.close();
insert_worker_.join();
}
std::vector<std::pair<TrieNode*, int32_t>>
Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
auto draft_token_num = param_.get_draft_token_num(batch_size);
auto min_match_window_size = param_.get_min_match_window_size(batch_size);
auto max_match_window_size = param_.max_match_window_size;
std::vector<std::pair<TrieNode*, int32_t>> result;
result.reserve(param_.max_match_window_size - param_.min_match_window_size);
for (int32_t match_window_size = std::min(tokens.size(), param_.max_match_window_size);
match_window_size >= param_.min_match_window_size;
--match_window_size) {
auto start = tokens.data() + tokens.size() - match_window_size;
auto end = start + match_window_size;
auto cursor = root_;
while (start != end) {
auto iter = cursor->child.find(*start);
if (iter == cursor->child.end()) {
cursor = nullptr;
break;
}
++start;
cursor = iter->second;
}
if (cursor) {
result.emplace_back(std::make_pair(cursor, match_window_size));
}
}
return result;
}
// Evicts `count` nodes from the tail of the global LRU list, detaches each of
// them from its parent and puts it back into the free pool.
//
// Throws std::runtime_error when fewer than `count` nodes are currently in use,
// or when a node picked for eviction still has children.
void Lookahead::squeeze(size_t count) {
  if (node_pool_.size() < free_node_count_ + count) {
    throw std::runtime_error(
        "Insufficient node size to release required nodes. "
        "available to release: " +
        std::to_string(node_pool_.size() - free_node_count_) + ", required to release: " + std::to_string(count));
  }

  for (size_t released = 0; released != count; ++released) {
    auto* const victim = global_lru_.back();
    global_lru_.pop_back();
    if (!victim->child.empty()) {
      throw std::runtime_error("The node to be released still has child nodes and cannot be released. ");
    }

    // Remove every reference the parent holds to the evicted node.
    auto* const owner = victim->parent;
    owner->lru.erase(victim->parent_lru_pos);
    owner->sorted_children.erase(victim);
    owner->child.erase(victim->token);

    // Hand the node back to the free pool.
    node_pool_[free_node_count_] = victim;
    ++free_node_count_;
  }
}
// Blocks the caller until the insert queue reports empty, polling every
// 10 microseconds.
// NOTE(review): an empty queue only shows that every batch has been dequeued.
// insert() pops a batch before it takes mutex_, so the last batch may still be
// in flight when this returns — confirm whether callers need a stronger
// guarantee.
void Lookahead::synchronize() const {
  const auto poll_interval = std::chrono::microseconds(10);
  for (;;) {
    if (insert_queue_.empty()) {
      return;
    }
    std::this_thread::sleep_for(poll_interval);
  }
}
// Body of the background worker thread (started in the constructor). Repeatedly
// takes a token sequence from insert_queue_ and inserts every sliding window of
// at most branch_length tokens into the trie, updating child frequencies, the
// per-parent LRU lists and the global LRU list.
// NOTE(review): squeeze() can throw std::runtime_error; nothing in this
// function catches it, so it would escape the thread's entry function.
void Lookahead::insert() {
while (!quit_flag_) {
std::vector<int32_t> data;
// A false return means there is nothing to process; loop back and re-check quit_flag_.
if (!insert_queue_.dequeue(data)) {
continue;
}
const auto* token = data.data();
size_t size = data.size();
// Held for the whole sequence; batchMatch() locks the same mutex.
std::unique_lock<std::mutex> lock(mutex_);
// Only start positions with more than min_match_window_size tokens remaining are inserted.
for (size_t i = 0; i + param_.min_match_window_size < size; ++i) {
auto start = token + i;
auto end = start + std::min(size - i, param_.branch_length);
// Worst case every token of the window needs a fresh node, so make sure that
// many nodes are free before walking.
if (end - start > free_node_count_) {
squeeze(end - start - free_node_count_);
}
TrieNode* cursor = root_;
path_.clear();
while (start != end) {
// Note: this local shadows the outer `token` pointer.
auto token = *start;
auto iter = cursor->child.find(token);
if (iter == cursor->child.end()) {
// Unseen continuation: take a node from the pool, put it at the front of the
// parent's LRU and at the back of the global LRU, and remember both positions.
iter = cursor->child.insert({token, getNode()}).first;
auto node = iter->second;
cursor->lru.emplace_front(node);
global_lru_.emplace_back(node);
node->token = token;
node->parent = cursor;
node->parent_lru_pos = cursor->lru.begin();
node->global_lru_pos = --global_lru_.end();
node->freq = 1;
cursor->sorted_children.insert(node);
} else {
// Known continuation: bump its frequency and move it to the front of the
// parent's LRU.
// NOTE(review): presumably sorted_children is ordered by freq, hence the erase
// before the increment and the re-insert after — confirm against the
// comparator in lookahead.h.
auto node = iter->second;
cursor->sorted_children.erase(node);
node->freq++;
cursor->sorted_children.insert(node);
cursor->lru.splice(cursor->lru.begin(), cursor->lru, node->parent_lru_pos);
}
cursor = iter->second;
path_.emplace_back(cursor);
++start;
}
// Move the visited nodes to the front of the global LRU, deepest first, so the
// shallowest node ends up at the very front and each node sits behind its
// ancestors. squeeze() evicts from the back of this list.
for (auto it = path_.rbegin(); it != path_.rend(); ++it) {
TrieNode* node = *it;
global_lru_.splice(global_lru_.begin(), global_lru_, node->global_lru_pos);
}
}
}
}
// Hands every token sequence in `tokens` to the insert queue for the background
// worker. Each sequence is moved out of `tokens`.
void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
  for (auto it = tokens.begin(); it != tokens.end(); ++it) {
    insert_queue_.enqueue(std::move(*it));
  }
}
// Builds a draft tree by expanding every matched trie node breadth-first,
// visiting children in the order of each node's LRU list, and returns the
// flattened tree.
//
// tokens:     context tokens; the last one becomes position 0 of the result.
// batch_size: selects the per-batch draft token budget.
//
// The breadth allowed at the first level shrinks linearly with the match depth
// (from about max_bfs_breadth for the shortest window down towards
// min_bfs_breadth for the longest), and drops by the same step per tree level.
// At least one child is always considered.
Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
  const std::vector<std::pair<TrieNode*, int32_t>> matches = match(tokens, batch_size);
  const auto draft_token_num = param_.get_draft_token_num(batch_size);
  const double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
                                   (param_.max_match_window_size - param_.min_match_window_size + 1);

  std::vector<Node> tree(draft_token_num + 1);
  const int root = 0;
  int cursor = 1;  // next unused slot in `tree`

  // (parent slot in `tree`, breadth budget, trie node to expand)
  using Pending = std::tuple<int32_t, double, const TrieNode*>;

  for (const auto& [start_node, depth] : matches) {
    std::queue<Pending> frontier;
    frontier.emplace(
        root, (param_.max_match_window_size - depth) * bfs_breadth_scale + param_.min_bfs_breadth, start_node);

    while (!frontier.empty() && cursor <= draft_token_num) {
      const auto [parent, cur_breadth, trie_node] = frontier.front();
      frontier.pop();

      const int32_t breadth = std::max(1, int32_t(cur_breadth));
      int32_t taken = 0;
      for (auto child = trie_node->lru.begin();
           taken < breadth && child != trie_node->lru.end() && cursor <= draft_token_num;
           ++taken, ++child) {
        // Reuse the slot when an earlier match already added this token under
        // the same parent; otherwise claim the next free slot.
        const auto [slot, inserted] = tree[parent].next.try_emplace((*child)->token, cursor);
        if (inserted) {
          ++cursor;
        }
        frontier.emplace(slot->second, cur_breadth - bfs_breadth_scale, *child);
      }
    }
  }
  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
// Builds a draft tree by repeatedly expanding the candidate with the highest
// accumulated probability, and returns the flattened tree.
//
// tokens:     context tokens; the last one becomes position 0 of the result.
// batch_size: selects the per-batch draft token budget.
//
// A child's probability is its frequency divided by the summed frequency of the
// first max_bfs_breadth entries of its parent's sorted_children, multiplied by
// the parent's accumulated probability.
Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
  const std::vector<std::pair<TrieNode*, int32_t>> matches = match(tokens, batch_size);
  const auto draft_token_num = param_.get_draft_token_num(batch_size);

  // (parent slot in `tree`, trie node, accumulated probability). The parent
  // slot is an index and is therefore kept as an integer.
  using Candidate = std::tuple<int32_t, const TrieNode*, double>;
  struct LowerProbability {
    bool operator()(const Candidate& a, const Candidate& b) const {
      return std::get<2>(a) < std::get<2>(b);
    }
  };
  std::priority_queue<Candidate, std::vector<Candidate>, LowerProbability> heap;

  std::vector<Node> tree(draft_token_num + 1);
  const int root = 0;
  int cursor = 1;  // next unused slot in `tree`
  const int top_k = param_.max_bfs_breadth;

  // Pushes up to top_k children of `trie_node` onto the heap, attached to the
  // tree slot `parent`, with their frequencies normalised and scaled by `prob`.
  const auto addToHeap = [&heap, top_k](int parent, const TrieNode* trie_node, double prob) {
    std::vector<const TrieNode*> top_children;
    double sum_freq = 0.0;
    int count = 0;
    for (const auto* child : trie_node->sorted_children) {
      sum_freq += static_cast<double>(child->freq);
      top_children.push_back(child);
      if (++count >= top_k) break;
    }
    if (sum_freq <= 0) sum_freq = 1.0;  // guard against division by zero
    for (const TrieNode* child : top_children) {
      heap.emplace(parent, child, static_cast<double>(child->freq) / sum_freq * prob);
    }
  };

  for (const auto& match_entry : matches) {
    addToHeap(root, match_entry.first, 1.0);
    while (!heap.empty() && cursor <= draft_token_num) {
      const auto [parent, trie_node, prob] = heap.top();
      heap.pop();
      // Reuse the slot when this token already hangs under the same parent;
      // otherwise claim the next free slot.
      const auto [slot, inserted] = tree[parent].next.try_emplace(trie_node->token, cursor);
      if (inserted) {
        ++cursor;
      }
      addToHeap(slot->second, trie_node, prob);
    }
  }
  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
// Runs the configured matcher (matchBFS when match_type is "BFS", matchProb
// otherwise) on every sequence in `tokens` while holding mutex_, and
// concatenates the per-sequence tokens and masks in input order.
Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
  std::lock_guard<std::mutex> guard(mutex_);

  const size_t batch_size = tokens.size();
  const auto matcher = (param_.match_type == "BFS") ? &Lookahead::matchBFS : &Lookahead::matchProb;

  Result merged;
  for (const auto& sequence : tokens) {
    const Result partial = (this->*matcher)(sequence, batch_size);
    merged.token.insert(merged.token.end(), partial.token.begin(), partial.token.end());
    merged.mask.insert(merged.mask.end(), partial.mask.begin(), partial.mask.end());
  }
  return merged;
}
// Shrinks the result to its first `n` tokens and to the top-left n x n corner
// of the row-major mask. Does nothing when `n` is not smaller than the current
// token count.
void Lookahead::Result::truncate(size_t n) {
  const size_t full_n = token.size();
  if (n >= full_n) {
    return;
  }
  // Row 0 is already in place. Every later row slides towards the front of the
  // buffer, and its source and destination ranges can overlap (for example when
  // full_n == n + 1), so memmove is required: memcpy on overlapping ranges is
  // undefined behaviour.
  for (size_t row = 1; row < n; ++row) {
    memmove(&mask[row * n], &mask[row * full_n], sizeof(mask[0]) * n);
  }
  token.resize(n);
  mask.resize(n * n);
}
} // namespace lookahead