// Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
// Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
#include "lookahead.h"

#include <algorithm>
#include <limits>
#include <vector>
|
|
|
|
namespace lookahead {
|
|
|
|
/// One vertex of the flattened draft tree that the match routines build and
/// fillResult() consumes.
struct Node {
  /// Outgoing edges: draft token id -> index of the child vertex inside the
  /// tree vector that owns this node.
  std::unordered_map<int32_t, int32_t> next{};
};
|
|
|
|
/// Flattens a draft tree into a token list plus a square ancestor mask.
///
/// Position 0 of the result holds `last_token`. The remaining positions hold the
/// tree's tokens in breadth-first order, starting with the children of `root`.
/// The lists are then zero-padded until they contain `draft_token_num` entries
/// (padding entries use token 0 and position 0 as their parent).
///
/// `mask` is an n x n row-major matrix (n = number of result tokens): row i has a
/// 1 in column i and in the column of every ancestor of position i.
///
/// @param last_token       Token stored at position 0.
/// @param draft_token_num  Padding target: the result holds at least this many entries.
/// @param tree             Draft tree; `tree[k].next` maps token -> child vertex index.
/// @param root             Index of the root vertex inside `tree`.
/// @return The flattened tokens together with their ancestor mask.
Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
  // Work with an unsigned target so size comparisons are not mixed-sign; a
  // non-positive request simply means "no padding".
  const size_t target_len = draft_token_num > 0 ? static_cast<size_t>(draft_token_num) : 0;

  Lookahead::Result info;
  std::vector<int32_t> prevs;  // prevs[i]: result position of i's parent, -1 for position 0
  info.token.reserve(target_len);
  prevs.reserve(target_len);

  info.token.emplace_back(last_token);
  prevs.emplace_back(-1);

  // Breadth-first walk. Each entry is (token, vertex index in `tree`, result
  // position of the parent).
  std::queue<std::tuple<int32_t, int32_t, int32_t>> pending;
  for (const auto& [token, child] : tree[root].next) {
    pending.emplace(token, child, 0);
  }
  while (!pending.empty()) {
    const auto [token, vertex, parent_pos] = pending.front();
    pending.pop();
    info.token.emplace_back(token);
    prevs.emplace_back(parent_pos);
    const auto self_pos = static_cast<int32_t>(info.token.size() - 1);
    for (const auto& [child_token, child] : tree[vertex].next) {
      pending.emplace(child_token, child, self_pos);
    }
  }

  // zero padding to length
  while (info.token.size() < target_len) {
    info.token.emplace_back(0);
    prevs.emplace_back(0);
  }

  const size_t n = info.token.size();
  info.mask.resize(n * n, 0);
  for (size_t i = 0; i < n; ++i) {
    auto* row = info.mask.data() + i * n;
    if (prevs[i] != -1) {
      // A parent always sits at a smaller position than its child, so its row is
      // already complete. Inherit columns [0, parent] from it. The copy is
      // element-wise so it stays correct whatever the mask element size is.
      const auto parent = static_cast<size_t>(prevs[i]);
      const auto* parent_row = info.mask.data() + parent * n;
      std::copy(parent_row, parent_row + parent + 1, row);
    }
    row[i] = 1;  // every position attends to itself
  }

  return info;
}
|
|
|
|
/// Sets up the node pool, validates the configuration and starts the insert worker.
///
/// @param capacity  Number of trie nodes to allocate up front.
/// @param param     Configuration; a copy is stored in the instance.
/// @throws std::runtime_error when any of the configuration constraints checked
///         below does not hold.
Lookahead::Lookahead(size_t capacity, const Param& param) {
  param_ = param;

  // Allocate every node once and register all of them in the pool; the root is
  // then obtained through getNode().
  nodes_.resize(capacity);
  for (auto& slot : nodes_) {
    node_pool_.emplace_back(&slot);
  }
  free_node_count_ = node_pool_.size();
  root_ = getNode();

  // Every violated constraint is reported the same way.
  const auto reject = [](const std::string& message) { throw std::runtime_error(message); };

  // --- window / branch constraints ---
  if (!(param_.branch_length > 1)) {
    reject("param_.branch_length must be greater than 1, current value: " + std::to_string(param_.branch_length));
  }
  if (!(param_.min_match_window_size > 0)) {
    reject(
        "min_match_window_size must be greater than 0, current value: " + std::to_string(param_.min_match_window_size));
  }
  if (!(param_.min_match_window_size <= param_.max_match_window_size)) {
    reject(
        "min_match_window_size must be less than or equal to max_match_window_size, current min_match_window_size: " +
        std::to_string(param_.min_match_window_size) +
        ", max_match_window_size: " + std::to_string(param_.max_match_window_size));
  }
  if (!(param_.max_match_window_size < param_.branch_length)) {
    reject(
        "max_match_window_size must be less than branch_length, current max_match_window_size: " +
        std::to_string(param_.max_match_window_size) + ", branch_length: " + std::to_string(param_.branch_length));
  }

  // --- breadth constraints ---
  if (!(param_.min_bfs_breadth > 0)) {
    reject("min_bfs_breadth must be greater than 0, current value: " + std::to_string(param_.min_bfs_breadth));
  }
  if (!(param_.min_bfs_breadth <= param_.max_bfs_breadth)) {
    reject(
        "min_bfs_breadth must be less than or equal to max_bfs_breadth, current min_bfs_breadth: " +
        std::to_string(param_.min_bfs_breadth) + ", max_bfs_breadth: " + std::to_string(param_.max_bfs_breadth));
  }

  // --- draft size constraints ---
  if (!(param_.draft_token_num > 0)) {
    reject("draft_token_num must be greater than 0, current value: " + std::to_string(param_.draft_token_num));
  }

  // Per-batch overrides: an entry equal to the type's maximum value is skipped;
  // every other entry must stay within the global limits.
  for (auto limit : param_.batch_draft_token_num) {
    if (limit == std::numeric_limits<decltype(limit)>::max()) {
      continue;
    }
    if (!(limit <= param_.draft_token_num)) {
      reject(
          "batch_draft_token_num config value " + std::to_string(limit) +
          " must be less than or equal to draft_token_num: " + std::to_string(param_.draft_token_num));
    }
  }
  for (auto window : param_.batch_min_match_window_size) {
    if (window == std::numeric_limits<decltype(window)>::max()) {
      continue;
    }
    if (!(window >= param_.min_match_window_size)) {
      reject(
          "batch_min_match_window_size config value " + std::to_string(window) +
          " must be greater than or equal to min_match_window_size: " + std::to_string(param_.min_match_window_size));
    }
    if (!(window <= param_.max_match_window_size)) {
      reject(
          "batch_min_match_window_size config value " + std::to_string(window) +
          " must be less than or equal to max_match_window_size: " + std::to_string(param_.max_match_window_size));
    }
  }

  // Only start the worker once the configuration is known to be valid.
  quit_flag_ = false;
  insert_worker_ = std::thread(&Lookahead::insert, this);
}
|
|
|
|
/// Shuts down the background insert worker and waits for it to finish.
Lookahead::~Lookahead() {
  // Raise the flag before closing the queue: insert() re-checks quit_flag_
  // whenever dequeue() returns, so it must already be set by then.
  // (close() presumably unblocks a pending dequeue() - confirm against the
  // queue implementation.)
  quit_flag_ = true;
  insert_queue_.close();

  // The worker was started by the constructor; wait for it to exit.
  insert_worker_.join();
}
|
|
|
|
std::vector<std::pair<TrieNode*, int32_t>>
|
|
Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
|
auto draft_token_num = param_.get_draft_token_num(batch_size);
|
|
auto min_match_window_size = param_.get_min_match_window_size(batch_size);
|
|
auto max_match_window_size = param_.max_match_window_size;
|
|
std::vector<std::pair<TrieNode*, int32_t>> result;
|
|
result.reserve(param_.max_match_window_size - param_.min_match_window_size);
|
|
for (int32_t match_window_size = std::min(tokens.size(), param_.max_match_window_size);
|
|
match_window_size >= param_.min_match_window_size;
|
|
--match_window_size) {
|
|
auto start = tokens.data() + tokens.size() - match_window_size;
|
|
auto end = start + match_window_size;
|
|
auto cursor = root_;
|
|
while (start != end) {
|
|
auto iter = cursor->child.find(*start);
|
|
if (iter == cursor->child.end()) {
|
|
cursor = nullptr;
|
|
break;
|
|
}
|
|
++start;
|
|
cursor = iter->second;
|
|
}
|
|
if (cursor) {
|
|
result.emplace_back(std::make_pair(cursor, match_window_size));
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
void Lookahead::squeeze(size_t count) {
|
|
if (!(node_pool_.size() >= free_node_count_ + count)) {
|
|
throw std::runtime_error(
|
|
"Insufficient node size to release required nodes. "
|
|
"available to release: " +
|
|
std::to_string(node_pool_.size() - free_node_count_) + ", required to release: " + std::to_string(count));
|
|
}
|
|
while (count--) {
|
|
auto last = global_lru_.back();
|
|
global_lru_.pop_back();
|
|
|
|
if (!last->child.empty()) {
|
|
throw std::runtime_error("The node to be released still has child nodes and cannot be released. ");
|
|
}
|
|
|
|
last->parent->lru.erase(last->parent_lru_pos);
|
|
last->parent->sorted_children.erase(last);
|
|
last->parent->child.erase(last->token);
|
|
|
|
node_pool_[free_node_count_++] = last;
|
|
}
|
|
}
|
|
|
|
void Lookahead::synchronize() const {
|
|
while (!insert_queue_.empty()) {
|
|
std::this_thread::sleep_for(std::chrono::microseconds(10));
|
|
}
|
|
}
|
|
|
|
void Lookahead::insert() {
|
|
while (!quit_flag_) {
|
|
std::vector<int32_t> data;
|
|
if (!insert_queue_.dequeue(data)) {
|
|
continue;
|
|
}
|
|
const auto* token = data.data();
|
|
size_t size = data.size();
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
for (size_t i = 0; i + param_.min_match_window_size < size; ++i) {
|
|
auto start = token + i;
|
|
auto end = start + std::min(size - i, param_.branch_length);
|
|
|
|
if (end - start > free_node_count_) {
|
|
squeeze(end - start - free_node_count_);
|
|
}
|
|
|
|
TrieNode* cursor = root_;
|
|
path_.clear();
|
|
while (start != end) {
|
|
auto token = *start;
|
|
auto iter = cursor->child.find(token);
|
|
if (iter == cursor->child.end()) {
|
|
iter = cursor->child.insert({token, getNode()}).first;
|
|
auto node = iter->second;
|
|
|
|
cursor->lru.emplace_front(node);
|
|
global_lru_.emplace_back(node);
|
|
|
|
node->token = token;
|
|
node->parent = cursor;
|
|
node->parent_lru_pos = cursor->lru.begin();
|
|
node->global_lru_pos = --global_lru_.end();
|
|
node->freq = 1;
|
|
cursor->sorted_children.insert(node);
|
|
} else {
|
|
auto node = iter->second;
|
|
cursor->sorted_children.erase(node);
|
|
node->freq++;
|
|
cursor->sorted_children.insert(node);
|
|
cursor->lru.splice(cursor->lru.begin(), cursor->lru, node->parent_lru_pos);
|
|
}
|
|
cursor = iter->second;
|
|
path_.emplace_back(cursor);
|
|
++start;
|
|
}
|
|
|
|
for (auto it = path_.rbegin(); it != path_.rend(); ++it) {
|
|
TrieNode* node = *it;
|
|
global_lru_.splice(global_lru_.begin(), global_lru_, node->global_lru_pos);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
|
|
for (auto&& token : tokens) {
|
|
insert_queue_.enqueue(std::move(token));
|
|
}
|
|
}
|
|
|
|
/// Builds a draft tree by breadth-first expansion of the matched trie nodes.
///
/// Every suffix match returned by match() is expanded level by level. At each
/// visited trie node at most `breadth` children are taken from the front of its
/// `lru` list. The starting breadth is
/// (max_match_window_size - match length) * scale + min_bfs_breadth, it shrinks
/// by `scale` per level, and at least one child is always considered. Expansion
/// stops once the draft tree holds get_draft_token_num(batch_size) tokens.
///
/// @param tokens      Context tokens; must not be empty (the last one becomes
///                    position 0 of the result).
/// @param batch_size  Batch size used to look up the draft token budget.
/// @return The flattened draft tokens and ancestor mask produced by fillResult().
/// @throws std::runtime_error if `tokens` is empty.
Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
  // tokens.back() below would be undefined on an empty sequence.
  if (tokens.empty()) {
    throw std::runtime_error("tokens must not be empty");
  }

  std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);

  const double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
                                   (param_.max_match_window_size - param_.min_match_window_size + 1);

  const auto draft_token_num = param_.get_draft_token_num(batch_size);
  std::vector<Node> tree(draft_token_num + 1);
  const int root = 0;
  int cursor = 1;  // next free vertex index in `tree`

  for (const auto& [node, depth] : nodes) {
    // (parent vertex in `tree`, remaining breadth, trie node to expand)
    std::queue<std::tuple<int32_t, double, const TrieNode*>> frontier;
    frontier.emplace(
        root, (param_.max_match_window_size - depth) * bfs_breadth_scale + param_.min_bfs_breadth, node);

    while (!frontier.empty() && cursor <= draft_token_num) {
      const auto [parent, cur_breadth, trie_node] = frontier.front();
      frontier.pop();

      const int32_t breadth = std::max<int32_t>(1, static_cast<int32_t>(cur_breadth));
      int32_t taken = 0;
      for (auto iter = trie_node->lru.begin();
           taken < breadth && iter != trie_node->lru.end() && cursor <= draft_token_num;
           ++taken, ++iter) {
        // Reuse the vertex when an earlier (longer) match already added this
        // token under the same parent; otherwise claim the next free index.
        const auto [slot, inserted] = tree[parent].next.try_emplace((*iter)->token, cursor);
        if (inserted) {
          ++cursor;
        }
        frontier.emplace(slot->second, cur_breadth - bfs_breadth_scale, *iter);
      }
    }
  }

  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
|
|
|
|
/// Builds a draft tree by always expanding the most probable candidate first.
///
/// For a trie node, the first max_bfs_breadth entries of `sorted_children` are
/// considered; each child's probability is its frequency divided by the summed
/// frequency of those entries, multiplied by the probability of the path leading
/// to it. Candidates wait in a max-heap keyed on that probability and are added
/// to the draft tree until it holds get_draft_token_num(batch_size) tokens.
///
/// @param tokens      Context tokens; must not be empty (the last one becomes
///                    position 0 of the result).
/// @param batch_size  Batch size used to look up the draft token budget.
/// @return The flattened draft tokens and ancestor mask produced by fillResult().
/// @throws std::runtime_error if `tokens` is empty.
Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
  // tokens.back() below would be undefined on an empty sequence.
  if (tokens.empty()) {
    throw std::runtime_error("tokens must not be empty");
  }

  std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
  const auto draft_token_num = param_.get_draft_token_num(batch_size);

  // (parent vertex in `tree`, trie node, accumulated probability). The parent is
  // an index, so it is kept as an integer rather than a floating-point value.
  using Candidate = std::tuple<int32_t, const TrieNode*, double>;
  struct LowerProbability {
    bool operator()(const Candidate& a, const Candidate& b) const {
      return std::get<2>(a) < std::get<2>(b);
    }
  };
  std::priority_queue<Candidate, std::vector<Candidate>, LowerProbability> heap;

  std::vector<Node> tree(draft_token_num + 1);
  const int root = 0;
  int cursor = 1;  // next free vertex index in `tree`
  const int top_k = param_.max_bfs_breadth;

  // Scratch buffer reused by every expansion to avoid per-call allocations.
  std::vector<std::pair<const TrieNode*, int32_t>> top_children;

  // Pushes the first top_k children of `trie_node` onto the heap, each weighted
  // by its share of the collected frequencies times the parent's probability.
  const auto addToHeap = [&heap, &top_children, top_k](int32_t parent, const TrieNode* trie_node, double prob) {
    top_children.clear();
    double sum_freq = 0.0;
    int count = 0;
    for (auto* child : trie_node->sorted_children) {
      sum_freq += static_cast<double>(child->freq);
      top_children.emplace_back(child, child->freq);
      if (++count >= top_k) break;
    }
    if (sum_freq <= 0) sum_freq = 1.0;  // avoid dividing by zero
    for (const auto& [child, freq] : top_children) {
      heap.emplace(parent, child, static_cast<double>(freq) / sum_freq * prob);
    }
  };

  for (const auto& matched : nodes) {
    // Once the budget is used up nothing more can be added to the tree.
    if (!(cursor <= draft_token_num)) {
      break;
    }
    addToHeap(root, matched.first, 1.0);

    while (!heap.empty() && cursor <= draft_token_num) {
      const auto [parent, trie_node, prob] = heap.top();
      heap.pop();

      // Reuse the vertex when this token already hangs under the same parent;
      // otherwise claim the next free index.
      const auto [slot, inserted] = tree[parent].next.try_emplace(trie_node->token, cursor);
      if (inserted) {
        ++cursor;
      }
      addToHeap(slot->second, trie_node, prob);
    }
  }

  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
|
|
|
|
/// Runs the configured matcher on every sequence of a batch and concatenates the results.
///
/// The matcher is matchBFS when param_.match_type equals "BFS" and matchProb
/// otherwise; the number of sequences is passed to it as the batch size. The whole
/// call runs while holding mutex_.
///
/// @param tokens  One token sequence per request in the batch.
/// @return A Result whose `token` and `mask` vectors are the per-sequence vectors
///         appended in input order.
Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
  std::unique_lock<std::mutex> lock(mutex_);

  const bool use_bfs = (param_.match_type == "BFS");
  const size_t batch_size = tokens.size();

  Result merged;
  for (const auto& sequence : tokens) {
    Result part = use_bfs ? matchBFS(sequence, batch_size) : matchProb(sequence, batch_size);
    merged.token.insert(merged.token.end(), part.token.begin(), part.token.end());
    merged.mask.insert(merged.mask.end(), part.mask.begin(), part.mask.end());
  }
  return merged;
}
|
|
|
|
void Lookahead::Result::truncate(size_t n) {
|
|
if (n < token.size()) {
|
|
int full_n = token.size();
|
|
for (int i = 1; i < n; ++i) {
|
|
memcpy(&mask[i * n], &mask[i * full_n], sizeof(mask[0]) * n);
|
|
}
|
|
token.resize(n);
|
|
mask.resize(n * n);
|
|
}
|
|
}
|
|
|
|
} // namespace lookahead
|