[speculative decoding] rename lookahead to ngram (#11010)
Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
#include "lookahead.h"
|
||||
#include "ngram.h"
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
namespace lookahead {
|
||||
namespace ngram {
|
||||
|
||||
struct Node {
|
||||
std::unordered_map<int32_t, int32_t> next;
|
||||
};
|
||||
|
||||
Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
|
||||
Lookahead::Result info;
|
||||
Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
|
||||
Ngram::Result info;
|
||||
std::vector<int32_t> prevs;
|
||||
info.token.reserve(draft_token_num);
|
||||
prevs.reserve(draft_token_num);
|
||||
@@ -50,7 +50,7 @@ Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<No
|
||||
return info;
|
||||
}
|
||||
|
||||
Lookahead::Lookahead(size_t capacity, const Param& param) {
|
||||
Ngram::Ngram(size_t capacity, const Param& param) {
|
||||
param_ = param;
|
||||
nodes_.resize(capacity);
|
||||
for (auto& node : nodes_) {
|
||||
@@ -116,17 +116,16 @@ Lookahead::Lookahead(size_t capacity, const Param& param) {
|
||||
}
|
||||
|
||||
quit_flag_ = false;
|
||||
insert_worker_ = std::thread(&Lookahead::insert, this);
|
||||
insert_worker_ = std::thread(&Ngram::insert, this);
|
||||
}
|
||||
|
||||
Lookahead::~Lookahead() {
|
||||
Ngram::~Ngram() {
|
||||
quit_flag_ = true;
|
||||
insert_queue_.close();
|
||||
insert_worker_.join();
|
||||
}
|
||||
|
||||
std::vector<std::pair<TrieNode*, int32_t>>
|
||||
Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||
std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||
auto draft_token_num = param_.get_draft_token_num(batch_size);
|
||||
auto min_match_window_size = param_.get_min_match_window_size(batch_size);
|
||||
auto max_match_window_size = param_.max_match_window_size;
|
||||
@@ -154,7 +153,7 @@ Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||
return result;
|
||||
}
|
||||
|
||||
void Lookahead::squeeze(size_t count) {
|
||||
void Ngram::squeeze(size_t count) {
|
||||
if (!(node_pool_.size() >= free_node_count_ + count)) {
|
||||
throw std::runtime_error(
|
||||
"Insufficient node size to release required nodes. "
|
||||
@@ -177,13 +176,13 @@ void Lookahead::squeeze(size_t count) {
|
||||
}
|
||||
}
|
||||
|
||||
void Lookahead::synchronize() const {
|
||||
void Ngram::synchronize() const {
|
||||
while (!insert_queue_.empty()) {
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(10));
|
||||
}
|
||||
}
|
||||
|
||||
void Lookahead::insert() {
|
||||
void Ngram::insert() {
|
||||
while (!quit_flag_) {
|
||||
std::vector<int32_t> data;
|
||||
if (!insert_queue_.dequeue(data)) {
|
||||
@@ -239,13 +238,13 @@ void Lookahead::insert() {
|
||||
}
|
||||
}
|
||||
|
||||
void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
|
||||
void Ngram::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
|
||||
for (auto&& token : tokens) {
|
||||
insert_queue_.enqueue(std::move(token));
|
||||
}
|
||||
}
|
||||
|
||||
Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||
Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
|
||||
|
||||
double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
|
||||
@@ -284,7 +283,7 @@ Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t
|
||||
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
|
||||
}
|
||||
|
||||
Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||
Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
|
||||
auto draft_token_num = param_.get_draft_token_num(batch_size);
|
||||
|
||||
@@ -346,10 +345,10 @@ Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_
|
||||
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
|
||||
}
|
||||
|
||||
Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
|
||||
Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
Result merged_result;
|
||||
auto match_func = param_.match_type == "BFS" ? &Lookahead::matchBFS : &Lookahead::matchProb;
|
||||
auto match_func = param_.match_type == "BFS" ? &Ngram::matchBFS : &Ngram::matchProb;
|
||||
for (const auto& tks : tokens) {
|
||||
Result res = (this->*match_func)(tks, tokens.size());
|
||||
merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());
|
||||
@@ -358,7 +357,7 @@ Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>&
|
||||
return merged_result;
|
||||
}
|
||||
|
||||
void Lookahead::Result::truncate(size_t n) {
|
||||
void Ngram::Result::truncate(size_t n) {
|
||||
if (n < token.size()) {
|
||||
int full_n = token.size();
|
||||
for (int i = 1; i < n; ++i) {
|
||||
@@ -369,4 +368,4 @@ void Lookahead::Result::truncate(size_t n) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace lookahead
|
||||
} // namespace ngram
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "param.h"
|
||||
#include "queue.h"
|
||||
|
||||
namespace lookahead {
|
||||
namespace ngram {
|
||||
|
||||
struct TrieNode {
|
||||
std::unordered_map<int32_t, TrieNode*> child;
|
||||
@@ -34,7 +34,7 @@ struct TrieNode {
|
||||
std::multiset<TrieNode*, CompareByFreq> sorted_children;
|
||||
};
|
||||
|
||||
class Lookahead {
|
||||
class Ngram {
|
||||
std::vector<TrieNode> nodes_;
|
||||
std::vector<TrieNode*> node_pool_;
|
||||
size_t free_node_count_;
|
||||
@@ -61,12 +61,12 @@ class Lookahead {
|
||||
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;
|
||||
|
||||
public:
|
||||
Lookahead(size_t capacity, const Param& param);
|
||||
Lookahead() = default;
|
||||
~Lookahead();
|
||||
Ngram(size_t capacity, const Param& param);
|
||||
Ngram() = default;
|
||||
~Ngram();
|
||||
|
||||
static Lookahead& instance() {
|
||||
static Lookahead instance;
|
||||
static Ngram& instance() {
|
||||
static Ngram instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
@@ -107,4 +107,4 @@ class Lookahead {
|
||||
void insert();
|
||||
};
|
||||
|
||||
} // namespace lookahead
|
||||
} // namespace ngram
|
||||
@@ -1,7 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# from sglang.op.lookahead import Lookahead, Param
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
@@ -12,17 +10,17 @@ from torch.utils.cpp_extension import load
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||
lookahead_cache_cpp = load(
|
||||
name="lookahead_cache_cpp",
|
||||
ngram_cache_cpp = load(
|
||||
name="ngram_cache_cpp",
|
||||
sources=[
|
||||
f"{_abs_path}/lookahead_cache_binding.cpp",
|
||||
f"{_abs_path}/lookahead.cpp",
|
||||
f"{_abs_path}/ngram_cache_binding.cpp",
|
||||
f"{_abs_path}/ngram.cpp",
|
||||
],
|
||||
extra_cflags=["-O3", "-std=c++20"],
|
||||
)
|
||||
|
||||
|
||||
class LookaheadCache:
|
||||
class NgramCache:
|
||||
def __init__(
|
||||
self,
|
||||
branch_length=18,
|
||||
@@ -34,7 +32,7 @@ class LookaheadCache:
|
||||
match_type="BFS",
|
||||
capacity=1000000,
|
||||
):
|
||||
param = lookahead_cache_cpp.Param()
|
||||
param = ngram_cache_cpp.Param()
|
||||
param.branch_length = branch_length
|
||||
param.min_match_window_size = min_match_window_size
|
||||
param.max_match_window_size = max_match_window_size
|
||||
@@ -42,7 +40,7 @@ class LookaheadCache:
|
||||
param.max_bfs_breadth = max_bfs_breadth
|
||||
param.draft_token_num = draft_token_num
|
||||
param.match_type = match_type
|
||||
self.cache = lookahead_cache_cpp.Lookahead(capacity, param)
|
||||
self.cache = ngram_cache_cpp.Ngram(capacity, param)
|
||||
|
||||
self.default_mask = np.ones((1, 1), dtype=np.int64)
|
||||
self.draft_token_num = draft_token_num
|
||||
@@ -131,7 +129,7 @@ if __name__ == "__main__":
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
|
||||
[1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
|
||||
]
|
||||
cache = LookaheadCache(branch_length=12, draft_token_num=8)
|
||||
cache = NgramCache(branch_length=12, draft_token_num=8)
|
||||
cache.batch_put(token_ids)
|
||||
|
||||
cache.synchronize()
|
||||
@@ -1,19 +1,19 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "lookahead.h"
|
||||
#include "ngram.h"
|
||||
|
||||
PYBIND11_MODULE(lookahead_cache_cpp, m) {
|
||||
using namespace lookahead;
|
||||
PYBIND11_MODULE(ngram_cache_cpp, m) {
|
||||
using namespace ngram;
|
||||
namespace py = pybind11;
|
||||
m.doc() = "";
|
||||
|
||||
py::class_<Lookahead>(m, "Lookahead")
|
||||
py::class_<Ngram>(m, "Ngram")
|
||||
.def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
|
||||
.def("asyncInsert", &Lookahead::asyncInsert, "")
|
||||
.def("batchMatch", &Lookahead::batchMatch, "")
|
||||
.def("reset", &Lookahead::reset, "")
|
||||
.def("synchronize", &Lookahead::synchronize, "");
|
||||
.def("asyncInsert", &Ngram::asyncInsert, "")
|
||||
.def("batchMatch", &Ngram::batchMatch, "")
|
||||
.def("reset", &Ngram::reset, "")
|
||||
.def("synchronize", &Ngram::synchronize, "");
|
||||
|
||||
py::class_<Param>(m, "Param")
|
||||
.def(py::init<>())
|
||||
@@ -35,9 +35,9 @@ PYBIND11_MODULE(lookahead_cache_cpp, m) {
|
||||
.def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
|
||||
.def("detail", &Param::detail, "");
|
||||
|
||||
py::class_<Lookahead::Result>(m, "Result")
|
||||
py::class_<Ngram::Result>(m, "Result")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("token", &Lookahead::Result::token)
|
||||
.def_readwrite("mask", &Lookahead::Result::mask)
|
||||
.def("truncate", &Lookahead::Result::truncate);
|
||||
.def_readwrite("token", &Ngram::Result::token)
|
||||
.def_readwrite("mask", &Ngram::Result::mask)
|
||||
.def("truncate", &Ngram::Result::truncate);
|
||||
}
|
||||
@@ -9,7 +9,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace lookahead {
|
||||
namespace ngram {
|
||||
|
||||
struct Param {
|
||||
bool enable;
|
||||
@@ -122,4 +122,4 @@ struct Param {
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace lookahead
|
||||
} // namespace ngram
|
||||
@@ -42,7 +42,7 @@ elif is_hip():
|
||||
|
||||
|
||||
@dataclass
|
||||
class LookaheadVerifyInput:
|
||||
class NgramVerifyInput:
|
||||
def __init__(
|
||||
self,
|
||||
draft_token: torch.Tensor,
|
||||
@@ -408,5 +408,5 @@ class LookaheadVerifyInput:
|
||||
def filter_batch(self, new_indices: torch.Tensor):
|
||||
pass
|
||||
|
||||
def merge_batch(self, spec_info: LookaheadVerifyInput):
|
||||
def merge_batch(self, spec_info: NgramVerifyInput):
|
||||
pass
|
||||
@@ -12,8 +12,8 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache
|
||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
|
||||
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import broadcast_pyobj
|
||||
|
||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
USE_FULL_MASK = True
|
||||
|
||||
|
||||
class LOOKAHEADWorker:
|
||||
class NGRAMWorker:
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
@@ -38,9 +38,9 @@ class LOOKAHEADWorker:
|
||||
self.tp_rank = tp_rank
|
||||
self.page_size = server_args.page_size
|
||||
self.draft_token_num: int = server_args.speculative_num_draft_tokens
|
||||
self.branch_length: int = server_args.speculative_lookahead_branch_length
|
||||
self.branch_length: int = server_args.speculative_ngram_branch_length
|
||||
self.max_match_window_size: int = (
|
||||
server_args.speculative_lookahead_max_match_window_size
|
||||
server_args.speculative_ngram_max_match_window_size
|
||||
)
|
||||
|
||||
self.max_batch_size = target_worker.max_running_requests
|
||||
@@ -48,18 +48,18 @@ class LOOKAHEADWorker:
|
||||
|
||||
self._init_preallocated_tensors()
|
||||
|
||||
self.lookahead_cache = LookaheadCache(
|
||||
min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
|
||||
max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
|
||||
min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
|
||||
max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
|
||||
capacity=server_args.speculative_lookahead_capacity,
|
||||
branch_length=server_args.speculative_lookahead_branch_length,
|
||||
self.ngram_cache = NgramCache(
|
||||
min_match_window_size=server_args.speculative_ngram_min_match_window_size,
|
||||
max_match_window_size=server_args.speculative_ngram_max_match_window_size,
|
||||
min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
|
||||
max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
|
||||
capacity=server_args.speculative_ngram_capacity,
|
||||
branch_length=server_args.speculative_ngram_branch_length,
|
||||
draft_token_num=server_args.speculative_num_draft_tokens,
|
||||
)
|
||||
|
||||
def clear_cache_pool(self):
|
||||
self.lookahead_cache.reset()
|
||||
self.ngram_cache.reset()
|
||||
|
||||
def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
|
||||
seq2_len = len(seq2)
|
||||
@@ -124,14 +124,14 @@ class LOOKAHEADWorker:
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
bs = batch.batch_size()
|
||||
|
||||
self.lookahead_cache.synchronize()
|
||||
self.ngram_cache.synchronize()
|
||||
batch_tokens = []
|
||||
for req in batch.reqs:
|
||||
check_token = self._efficient_concat_last_n(
|
||||
req.origin_input_ids, req.output_ids, self.max_match_window_size
|
||||
)
|
||||
batch_tokens.append(check_token)
|
||||
req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens)
|
||||
req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
|
||||
total_draft_token_num = len(req_drafts)
|
||||
|
||||
# Check if speculative decoding is needed; here we always enforce it
|
||||
@@ -184,9 +184,9 @@ class LOOKAHEADWorker:
|
||||
tree_mask.append(req_mask.flatten())
|
||||
tree_mask = torch.cat(tree_mask, dim=0)
|
||||
|
||||
batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD
|
||||
batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.spec_info = LookaheadVerifyInput(
|
||||
batch.spec_info = NgramVerifyInput(
|
||||
draft_tokens,
|
||||
tree_mask,
|
||||
positions,
|
||||
@@ -197,7 +197,7 @@ class LOOKAHEADWorker:
|
||||
)
|
||||
batch.spec_info.prepare_for_verify(batch, self.page_size)
|
||||
|
||||
def _update_lookahead_cache(self, batch: ScheduleBatch):
|
||||
def _update_ngram_cache(self, batch: ScheduleBatch):
|
||||
batch_tokens = []
|
||||
for req in batch.reqs:
|
||||
# FIXME: Whether to insert 'extend' into the cache or not, after testing,
|
||||
@@ -209,7 +209,7 @@ class LOOKAHEADWorker:
|
||||
req.origin_input_ids, req.output_ids, self.branch_length
|
||||
)
|
||||
batch_tokens.append(put_ids)
|
||||
self.lookahead_cache.batch_put(batch_tokens)
|
||||
self.ngram_cache.batch_put(batch_tokens)
|
||||
|
||||
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
||||
self._prepare_for_speculative_decoding(batch)
|
||||
@@ -227,7 +227,7 @@ class LOOKAHEADWorker:
|
||||
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
|
||||
batch, logits_output, self.page_size
|
||||
)
|
||||
self._update_lookahead_cache(batch)
|
||||
self._update_ngram_cache(batch)
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
|
||||
else:
|
||||
@@ -6,7 +6,7 @@ class SpeculativeAlgorithm(IntEnum):
|
||||
EAGLE = auto()
|
||||
EAGLE3 = auto()
|
||||
STANDALONE = auto()
|
||||
LOOKAHEAD = auto()
|
||||
NGRAM = auto()
|
||||
|
||||
def is_none(self):
|
||||
return self == SpeculativeAlgorithm.NONE
|
||||
@@ -20,8 +20,8 @@ class SpeculativeAlgorithm(IntEnum):
|
||||
def is_standalone(self):
|
||||
return self == SpeculativeAlgorithm.STANDALONE
|
||||
|
||||
def is_lookahead(self):
|
||||
return self == SpeculativeAlgorithm.LOOKAHEAD
|
||||
def is_ngram(self):
|
||||
return self == SpeculativeAlgorithm.NGRAM
|
||||
|
||||
@staticmethod
|
||||
def from_string(name: str):
|
||||
@@ -29,7 +29,7 @@ class SpeculativeAlgorithm(IntEnum):
|
||||
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
||||
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
|
||||
"STANDALONE": SpeculativeAlgorithm.STANDALONE,
|
||||
"LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD,
|
||||
"NGRAM": SpeculativeAlgorithm.NGRAM,
|
||||
None: SpeculativeAlgorithm.NONE,
|
||||
}
|
||||
if name is not None:
|
||||
|
||||
Reference in New Issue
Block a user