From 815dce0554793d0788faf4eaacf0c7271c070e95 Mon Sep 17 00:00:00 2001
From: yukavio <67678385+yukavio@users.noreply.github.com>
Date: Thu, 2 Jan 2025 19:22:34 +0800
Subject: [PATCH] Eagle speculative decoding part 4: Add EAGLE2 worker (#2150)

Co-authored-by: kavioyu
Co-authored-by: Lianmin Zheng
---
 .../engine/EAGLE_offline_batch_inference.py   |  37 ++
 .../srt/speculative/build_eagle_tree.py       | 347 ++++++++++
 python/sglang/srt/speculative/eagle_utils.py  | 618 ++++++++++++++++++
 python/sglang/srt/speculative/eagle_worker.py | 170 +++++
 test/srt/run_suite.py                         |   1 +
 test/srt/test_eagle_infer.py                  |  39 ++
 6 files changed, 1212 insertions(+)
 create mode 100644 examples/runtime/engine/EAGLE_offline_batch_inference.py
 create mode 100644 python/sglang/srt/speculative/build_eagle_tree.py
 create mode 100644 python/sglang/srt/speculative/eagle_utils.py
 create mode 100644 python/sglang/srt/speculative/eagle_worker.py
 create mode 100644 test/srt/test_eagle_infer.py

diff --git a/examples/runtime/engine/EAGLE_offline_batch_inference.py b/examples/runtime/engine/EAGLE_offline_batch_inference.py
new file mode 100644
index 000000000..0885959b3
--- /dev/null
+++ b/examples/runtime/engine/EAGLE_offline_batch_inference.py
@@ -0,0 +1,37 @@
+import sglang as sgl
+
+
+def main():
+    # Sample prompts.
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    # Create a sampling params object.
+    sampling_params = {"temperature": 0, "max_new_tokens": 30}
+
+    # Create an LLM with EAGLE speculative decoding enabled.
+    llm = sgl.Engine(
+        model_path="meta-llama/Llama-2-7b-chat-hf",
+        speculative_algorithm="EAGLE",
+        speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B",
+        speculative_num_steps=3,
+        speculative_eagle_topk=4,
+        speculative_num_draft_tokens=16,
+    )
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    for prompt, output in zip(prompts, outputs):
+        print("===============================")
+        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+
+
+# The __main__ guard is necessary because sgl.Engine uses "spawn" to create subprocesses.
+# Spawn starts a fresh program each time; without the guard, every child would re-run this script and keep spawning processes forever.
+if __name__ == "__main__":
+    main()
diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py
new file mode 100644
index 000000000..6412825ed
--- /dev/null
+++ b/python/sglang/srt/speculative/build_eagle_tree.py
@@ -0,0 +1,347 @@
+import cutex
+import torch
+
+# parent_table     [bs, topk * depth + 1]
+# selected_index   [bs, draft_token_num - 1]
+# verified_seq_len [bs]
+# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..]
= [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] +# positions [bs*draft_token] +# retrive_index [b, draft_token, depth+2] +kernels = cutex.SourceModule( + """ +//cuda +__global__ void build_tree(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, + Tensor tree_mask, Tensor positions, Tensor retrive_index, int topk, int depth, int draft_token_num) { + int bid = blockIdx.x; + int tid = threadIdx.x; + if (tid >= draft_token_num){ + return; + } + int seq_tree_idx = draft_token_num * draft_token_num * bid; + for(int i=0; i 1: + index = tl.arange(0, draft_token_num) + mask_left = index != max_index + remained_index = tl.where(mask_max and mask_left, index, 0) + max_index = tl.max(remained_index) + + tl.store(accept_length + pid, accept_len) + retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len + retrive_offset = tl.arange(0, max_len_upper) + retrive_load_mask = retrive_offset < accept_len + 1 + data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) + + tl.store( + accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask + ) + + extract_load_ptr = accept_index + pid * max_len + accept_len + if accept_len == max_len - 1: + extract_data = tl.load(extract_load_ptr - 1) + tl.store(extract_index + pid * 2, extract_data) + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2 + 1, extract_data) + + else: + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2, extract_data) + + +@triton.jit +def create_extend_spec_info( + verified_id, + seq_len, + accept_len, + accept_len_cum, + positions, + new_verified_id, + accept_len_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1) + seq_length = tl.load(seq_len + pid) + accept_length = tl.load(accept_len + pid) + positions_ptr = positions + offset + data = tl.arange(0, accept_len_upper) + mask = data < accept_length + tl.store(positions_ptr + data, seq_length - accept_length + data, mask) + + offset = tl.load(accept_len_cum + pid) - 1 + verified_id_data = tl.load(verified_id + offset) + tl.store(new_verified_id + pid, verified_id_data) + + +@triton.jit +def assign_req_to_token_pool( + req_pool_indices, + req_to_token, + start_offset, + end_offset, + out_cache_loc, + pool_len: tl.constexpr, + bs_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(start_offset + pid) + kv_end = tl.load(end_offset + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + + length_offset = tl.arange(0, bs_upper) + start = tl.load(start_offset + length_offset, mask=length_offset < pid) + end = tl.load(end_offset + length_offset, mask=length_offset < pid) + out_offset = tl.sum(end - start, axis=0) + + out_cache_ptr = out_cache_loc + out_offset + + save_offset = tl.arange(0, BLOCK_SIZE) + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) + save_offset += BLOCK_SIZE + load_offset += BLOCK_SIZE + + +@triton.jit +def generate_draft_decode_kv_indices( + req_pool_indices, + req_to_token, + paged_kernel_lens, + kv_indices, + iters: tl.constexpr, + topk: tl.constexpr, + pool_len: tl.constexpr, + bs_upper: tl.constexpr, + iter_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + bid = 
tl.program_id(axis=0)
+    topk_id = tl.program_id(axis=1)
+
+    load_offset = tl.arange(0, bs_upper)
+    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid)
+    seq_len = tl.load(paged_kernel_lens + bid)
+    cum_seq_len = tl.sum(seq_lens)
+
+    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
+    kv_ptr = kv_indices + kv_offset
+    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
+
+    kv_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = kv_offset < seq_len
+        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
+        tl.store(kv_ptr + kv_offset, data, mask=mask)
+        kv_offset += BLOCK_SIZE
+
+    extend_offset = tl.arange(0, iter_upper)
+    extend_data = tl.load(
+        token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id,
+        mask=extend_offset < iters,
+    )
+    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
+
+
+class EAGLEDraftInput(SpecInfo):
+    hidden_states: torch.Tensor = None
+    verified_id: torch.Tensor = None
+    positions: torch.Tensor = None
+    accept_length: torch.Tensor = None
+    has_finished: bool = False
+    unfinished_index: List[int] = None
+
+    def init(self, server_args: ServerArgs):
+        self.prev_mode = ForwardMode.DECODE
+        self.sample_output = None
+        self.topk: int = server_args.speculative_eagle_topk
+        self.num_verify_token: int = server_args.speculative_num_draft_tokens
+        self.spec_steps = server_args.speculative_num_steps
+
+        self.scores: torch.Tensor = None
+        self.score_list: List[torch.Tensor] = []
+        self.token_list: List[torch.Tensor] = []
+        self.origin_score_list: List[torch.Tensor] = []  # used for sampling
+        self.parents_list: List[torch.Tensor] = []
+        self.cache_list: List[torch.Tensor] = []
+        self.iter = 0
+        self.root_token: int = None
+
+        assert self.topk <= 10, "topk should be <= 10"
+
+    def prepare_for_extend(self, batch: ForwardBatch):
+        req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
+        out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+        batch.out_cache_loc = out_cache_loc
+
+        pt = 0
+        for i, req in enumerate(batch.reqs):
+            req.req_pool_idx = req_pool_indices[i]
+            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
+            assert seq_len - pre_len == req.extend_input_len
+
+            if pre_len > 0:
+                batch.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    :pre_len
+                ] = req.prefix_indices
+
+            batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                out_cache_loc[pt : pt + req.extend_input_len]
+            )
+
+            pt += req.extend_input_len
+
+        # Shift each request's input ids by one and append its verified token.
+        seq_lens = [0] + batch.extend_lens
+        input_ids = batch.input_ids.tolist()
+        verified_id = batch.spec_info.verified_id.tolist()
+        model_input_ids = []
+        for i in range(len(seq_lens) - 1):
+            model_input_ids.extend(
+                input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
+            )
+        batch.input_ids = torch.tensor(
+            model_input_ids, dtype=torch.int32, device="cuda"
+        )
+
+    def capture_for_decode(
+        self,
+        sample_output: SampleOutput,
+        hidden_states: torch.Tensor,
+        prev_mode: ForwardMode,
+    ):
+        self.sample_output = sample_output
+        self.prev_mode = prev_mode
+        self.hidden_states = hidden_states
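+
+    # Worked sketch of one step of `prepare_for_decode` below (hypothetical
+    # numbers, topk=2, batch size 1). `self.scores` holds the cumulative
+    # probability of each live leaf; `topk_p` holds each leaf's next-token
+    # probabilities:
+    #
+    #   scores            = [[0.6, 0.4]]                # (b, topk)
+    #   topk_p (reshaped) = [[[0.5, 0.3],               # (b, topk, topk)
+    #                         [0.8, 0.1]]]
+    #   joint, flattened  = [[0.30, 0.18, 0.32, 0.04]]
+    #
+    # torch.topk keeps [0.32, 0.30] at flat indices [2, 0], and
+    # index // topk = [1, 0] recovers the parent leaf of each kept path,
+    # which is what `selected_input_index` uses to gather hidden states.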
+    def prepare_for_decode(self, batch: ScheduleBatch):
+        prob = self.sample_output  # (b * (1 or topk), vocab)
+        top = torch.topk(prob, self.topk, dim=-1)
+        topk_index, topk_p = top.indices, top.values  # (b * (1 or topk), topk)
+        if self.prev_mode == ForwardMode.DECODE:
+            scores = torch.mul(
+                self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
+            )  # (b, topk, 1) * (b, topk, topk) -> (b, topk, topk)
+            topk_cs = torch.topk(
+                scores.flatten(start_dim=1), self.topk, dim=-1
+            )  # (b, topk)
+            topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
+            self.scores = topk_cs_p
+
+            selected_input_index = topk_cs_index.flatten() // self.topk  # (b * topk,)
+
+            batch.spec_info.hidden_states = batch.spec_info.hidden_states[
+                selected_input_index, :
+            ]
+            topk_index = topk_index.reshape(-1, self.topk**2)
+            batch.input_ids = torch.gather(
+                topk_index, index=topk_cs_index, dim=1
+            ).flatten()
+            batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+            self.score_list.append(scores)  # (b, topk, topk)
+            self.token_list.append(topk_index)  # (b, topk * topk)
+            self.origin_score_list.append(topk_p.reshape(topk_index.shape))
+            self.parents_list.append(
+                topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
+            )  # (b, topk)
+
+        elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
+            self.scores = topk_p  # (b, topk)
+            self.score_list.append(topk_p.unsqueeze(1))
+            self.token_list.append(topk_index)
+            self.origin_score_list.append(topk_p)
+            batch.spec_info.hidden_states = (
+                batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
+            )
+            batch.input_ids = topk_index.flatten()
+            batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
+            self.parents_list.append(
+                torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
+                .unsqueeze(0)
+                .repeat(self.scores.shape[0], 1)
+            )  # (b, topk + 1)
+        self.cache_list.append(batch.out_cache_loc)
+        self.positions = (
+            batch.seq_lens[:, None]
+            + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+        ).flatten()
+
+        bs = batch.seq_lens.numel()
+        assign_req_to_token_pool[(bs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            batch.seq_lens + self.topk * self.iter,
+            batch.seq_lens + self.topk * (self.iter + 1),
+            batch.out_cache_loc,
+            batch.req_to_token_pool.req_to_token.shape[1],
+            triton.next_power_of_2(bs),
+        )
+        self.iter += 1
+
+    def prepare_extend_after_decode(self, batch: ScheduleBatch):
+        batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
+        batch.extend_lens = (self.accept_length + 1).tolist()
+
+        pt = 0
+        seq_lens = batch.seq_lens.tolist()
+
+        i = 0
+        for req in batch.reqs:
+            if req.finished():
+                continue
+            # assert seq_len - pre_len == req.extend_input_len
+            input_len = self.accept_length[i] + 1
+            seq_len = seq_lens[i]
+            batch.req_to_token_pool.req_to_token[req.req_pool_idx][
+                seq_len - input_len : seq_len
+            ] = batch.out_cache_loc[pt : pt + input_len]
+            pt += input_len
+            i += 1
+
+        self.positions = torch.empty_like(self.verified_id)
+        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
+        self.accept_length.add_(1)
+
+        create_extend_spec_info[(self.accept_length.numel(),)](
+            self.verified_id,
+            batch.seq_lens,
+            self.accept_length,
+            torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
+            self.positions,
+            new_verified_id,
+            triton.next_power_of_2(self.spec_steps + 1),
+        )
+
+        batch.input_ids = self.verified_id
+        self.verified_id = new_verified_id
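+
+    # Sketch of the verify-set selection in `prepare_for_verify` below, with
+    # hypothetical numbers (topk=2, two draft steps, num_verify_token=4): the
+    # concatenated scores hold one cumulative probability per tree node, e.g.
+    # [0.6, 0.4, 0.30, 0.18, 0.32, 0.04]. torch.topk keeps the
+    # num_verify_token - 1 = 3 best nodes (indices [0, 1, 4]); torch.sort
+    # keeps the indices in generation order so parents precede children, and
+    # the verified root token is prepended, giving num_verify_token tokens
+    # for a single target-model verify pass.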
+    def prepare_for_verify(self, batch: ScheduleBatch):
+        score_list = torch.cat(self.score_list, dim=1).flatten(
+            1
+        )  # (b, n, topk) -> (b, n * topk), where n = 1 + (self.iter - 1) * self.topk
+        ss_token_list = torch.cat(
+            self.token_list, dim=1
+        )  # (b, self.topk + (self.iter - 1) * self.topk * self.topk)
+        origin_token_list = torch.cat(self.origin_score_list, dim=1)
+        top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1)
+        top_scores_index = top_scores.indices
+        top_scores_index = torch.sort(top_scores_index).values
+
+        draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
+        scores = torch.gather(origin_token_list, index=top_scores_index, dim=1)
+        draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1)
+        parent_list = torch.cat(self.parents_list[:-1], dim=1)
+
+        tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel(
+            parent_list,
+            top_scores_index,
+            batch.seq_lens,
+            self.topk,
+            self.iter - 1,
+            self.num_verify_token,
+        )
+
+        return EagleVerifyInput(
+            draft_tokens.flatten(),
+            scores.flatten(),
+            tree_mask,
+            position,
+            retrive_index,
+            retrive_cum_len,
+            self.num_verify_token,
+        )
+
+    def generate_attn_arg_decode(
+        self,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        req_to_token: torch.Tensor,
+    ):
+        seq_num = req_pool_indices.numel()
+        bs = self.topk * req_pool_indices.numel()
+        seq_len = self.positions.reshape(-1).contiguous()
+
+        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0)
+        total_len = torch.sum(paged_kernel_lens).item()
+
+        kv_indices = torch.empty(
+            (total_len * self.topk + seq_num * self.iter * self.topk,),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)](
+            req_pool_indices,
+            req_to_token,
+            paged_kernel_lens,
+            kv_indices,
+            self.iter,
+            self.topk,
+            req_to_token.shape[1],
+            triton.next_power_of_2(seq_num),
+            triton.next_power_of_2(self.spec_steps),
+        )
+        return bs, kv_indices, cum_kv_seq_len
+
+    def clear(self):
+        self.iter = 0
+        self.score_list.clear()
+        self.positions = None
+
+    def clear_draft_cache(self, batch):
+        draft_cache = torch.cat(self.cache_list, dim=0)
+        batch.token_to_kv_pool.free(draft_cache)
+
+    def generate_attn_arg_prefill(
+        self,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        req_to_token: torch.Tensor,
+    ):
+        bs = self.accept_length.numel()
+        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
+
+        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            cum_kv_seq_len,
+            None,
+            kv_indices,
+            req_to_token.size(1),
+        )
+
+        return kv_indices, cum_kv_seq_len, qo_indptr, None
+
+    def merge_batch(self, spec_info: EAGLEDraftInput):
+        self.hidden_states = torch.cat(
+            [self.hidden_states, spec_info.hidden_states], axis=0
+        )
+        self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
+        # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
+        self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
+
+
+class EagleVerifyInput(SpecInfo):
+    def __init__(
+        self,
+        draft_token: torch.Tensor,
+        draft_score: torch.Tensor,
+        tree_mask: torch.Tensor,
+        positions: torch.Tensor,
+        retrive_index: torch.Tensor,
+        retrive_cum_len: torch.Tensor,
+        draft_token_num: int,
+    ):
+        self.draft_token = draft_token
+        self.draft_score = draft_score
+        self.custom_mask = tree_mask
+        self.positions = positions
+        self.retrive_index = retrive_index
+        self.retrive_cum_len = retrive_cum_len
+        self.draft_token_num = draft_token_num
+
+    def prepare_for_verify(self, 
batch: ScheduleBatch): + batch.input_ids = self.draft_token + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + bs = batch.seq_lens.numel() + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + self.draft_token_num, + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + batch_size = len(req_pool_indices) + qo_indptr = torch.arange( + 0, + (1 + batch_size) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device="cuda", + ) + + cum_kv_seq_len = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(batch_size,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask + + def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: + predict = torch.argmax(logits_output.next_token_logits, dim=-1) + predict = torch.cat( + [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1 + ) + draft_token = torch.cat( + [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")], + dim=-1, + ) + target_predict = predict[self.retrive_index] + candidates = draft_token[self.retrive_index] + # logits = logits_output.next_token_logits[self.retrive_index] + # target_predict = torch.argmax(logits[:, :-1], dim=-1) + accept_mask = candidates[:, 1:] == target_predict[:, :-1] + accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) + bs = self.retrive_cum_len.numel() - 1 + + max_draft_len = self.retrive_index.shape[-1] + accept_index = torch.full( + (bs, max_draft_len), -1, dtype=torch.long, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") + extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") + eagle_verify_retrive[(bs,)]( + self.retrive_index.contiguous(), + accept_mask.contiguous(), + self.retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_draft_len, + self.draft_token_num, + triton.next_power_of_2(max_draft_len), + ) + + accept_index = accept_index[accept_index != -1] + # extract_index = extract_index[extract_index != 0] + + draft_input = EAGLEDraftInput() + + accept_length_cpu = accept_length.tolist() + verified_id = predict[accept_index] + verified_id_cpu = verified_id.tolist() + + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index] = False + mem_need_free_idx = batch.out_cache_loc[evict_mask] + batch.token_to_kv_pool.free(mem_need_free_idx) + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length + 1, + batch.out_cache_loc[accept_index], + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + batch.seq_lens.add_(accept_length + 1) + new_accept_index = [] + unfinished_index = [] + finished_extend_len = {} # {rid:accept_length + 1} + # retracted_reqs, new_token_ratio = 
batch.retract_decode()
+
+        low = 0
+        for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
+            req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
+            req.check_finished()
+            if req.finished():
+                draft_input.has_finished = True
+                finished_extend_len[req.rid] = verified_len + 1
+            else:
+                new_accept_index.append(accept_index[low : low + verified_len + 1])
+                unfinished_index.append(i)
+            low += verified_len + 1
+
+        if len(new_accept_index) > 0:
+            new_accept_index = torch.cat(new_accept_index, dim=0)
+            draft_input.verified_id = predict[new_accept_index]
+            draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
+            draft_input.accept_length = accept_length[unfinished_index]
+            draft_input.unfinished_index = unfinished_index
+
+        logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
+        return draft_input, logits_output, verified_id, finished_extend_len
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
new file mode 100644
index 000000000..6701c66ac
--- /dev/null
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -0,0 +1,170 @@
+from typing import List, Optional, Union
+
+import torch
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
+
+
+class EAGLEWorker(TpModelWorker):
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        # Do not capture the cuda graph in `super().__init__()`;
+        # we will capture it later.
+        backup_disable_cuda_graph = server_args.disable_cuda_graph
+        server_args.disable_cuda_graph = True
+        super().__init__(
+            gpu_id=gpu_id,
+            tp_rank=tp_rank,
+            server_args=server_args,
+            nccl_port=nccl_port,
+            dp_rank=dp_rank,
+            is_draft_worker=True,
+        )
+        self.target_worker = target_worker
+        self.server_args = server_args
+
+        # Share the embedding and lm_head with the target model.
+        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+        self.model_runner.model.set_embed_and_head(embed, head)
+        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
+        self.model_runner.init_cuda_graphs()
+
+    def forward_draft_decode(self, batch: ScheduleBatch):
+        batch.spec_info.prepare_for_decode(batch)
+        model_worker_batch = batch.get_model_worker_batch()
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        logits_output = self.model_runner.forward(forward_batch)
+        self.capture_for_decode(logits_output, forward_batch)
+
+    def forward_draft_extend(self, batch: ScheduleBatch):
+        self._swap_mem_pool(batch, self.model_runner)
+        batch.spec_info.prepare_for_extend(batch)
+        model_worker_batch = batch.get_model_worker_batch()
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        logits_output = self.model_runner.forward(forward_batch)
+        self.capture_for_decode(logits_output, forward_batch)
+        self._swap_mem_pool(batch, self.target_worker.model_runner)
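+
+    # One speculative round in the decode branch below, condensed: grow the
+    # draft tree for `speculative_num_steps` steps on the draft worker's own
+    # KV pool, verify the flattened tree with the target model in one pass,
+    # then re-extend the draft model on the accepted tokens so its hidden
+    # states stay in sync for the next round:
+    #
+    #     self._swap_mem_pool(batch, self.model_runner)
+    #     for _ in range(self.server_args.speculative_num_steps):
+    #         self.forward_draft_decode(batch)
+    #     self._swap_mem_pool(batch, self.target_worker.model_runner)
+    #     next_draft_input, ... = self.verify(batch)
+    #     self.forward_extend_after_decode(batch)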
+    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+        if batch.forward_mode.is_decode():
+            prev_spec_info = batch.spec_info
+            self._swap_mem_pool(batch, self.model_runner)
+            for i in range(self.server_args.speculative_num_steps):
+                self.forward_draft_decode(batch)
+            batch.spec_info.clear_draft_cache(batch)
+            self._swap_mem_pool(batch, self.target_worker.model_runner)
+            (
+                next_draft_input,
+                logits_output,
+                verified_id,
+                self.finish_extend_len,
+                model_worker_batch,
+            ) = self.verify(batch)
+            next_draft_input.init(self.server_args)
+            batch.spec_info = next_draft_input
+            # If verified_id is None, all requests are finished.
+            if batch.spec_info.verified_id is not None:
+                self.forward_extend_after_decode(batch)
+            batch.spec_info = prev_spec_info
+            return logits_output, verified_id, model_worker_batch, next_draft_input
+
+        else:
+            spec_info = EAGLEDraftInput()
+            spec_info.init(self.server_args)
+            model_worker_batch = batch.get_model_worker_batch()
+            model_worker_batch.spec_info = spec_info
+            spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+                model_worker_batch
+            )
+            model_worker_batch.spec_info.verified_id = next_token_ids
+            model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
+            batch.spec_info = spec_info
+            self.forward_draft_extend(batch)
+            batch.spec_info = None
+            return logits_output, next_token_ids, model_worker_batch, spec_info
+
+    def verify(self, batch: ScheduleBatch):
+        verify_input = batch.spec_info.prepare_for_verify(batch)
+        batch.forward_mode = ForwardMode.TARGET_VERIFY
+        verify_input.prepare_for_verify(batch)
+        batch.spec_info = verify_input
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+        model_worker_batch = batch.get_model_worker_batch()
+        logits_output, _ = self.target_worker.forward_batch_generation(
+            model_worker_batch, skip_sample=True
+        )
+        verify_input.hidden_states = logits_output.hidden_states
+        res = verify_input.verify(batch, logits_output)
+        batch.forward_mode = ForwardMode.DECODE
+        return res + (model_worker_batch,)
+
+    def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
+        batch.token_to_kv_pool = runner.token_to_kv_pool
+        batch.req_to_token_pool = runner.req_to_token_pool
+
+    def forward_extend_after_decode(self, batch: ScheduleBatch):
+        self._swap_mem_pool(batch, self.model_runner)
+        batch.forward_mode = ForwardMode.DRAFT_EXTEND
+        if batch.spec_info.has_finished:
+            index = batch.spec_info.unfinished_index
+            seq_lens = batch.seq_lens
+            batch.seq_lens = batch.seq_lens[index]
+        batch.spec_info.prepare_extend_after_decode(batch)
+        model_worker_batch = batch.get_model_worker_batch()
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        logits_output = self.model_runner.forward(forward_batch)
+        batch.spec_info.hidden_states = logits_output.hidden_states
+        self.capture_for_decode(logits_output, forward_batch)
+        batch.forward_mode = ForwardMode.DECODE
+        if batch.spec_info.has_finished:
+            batch.seq_lens = seq_lens
+        self._swap_mem_pool(batch, self.target_worker.model_runner)
+
+    def capture_for_decode(self, logits_output, forward_batch):
+        if isinstance(logits_output, LogitsProcessorOutput):
+            logits = logits_output.next_token_logits
+            sample_output = torch.softmax(
+                logits, dim=-1
+            )  # TODO: Support more sampling methods. @kavioyu
+            forward_batch.spec_info.capture_for_decode(
+                sample_output, logits_output.hidden_states, forward_batch.forward_mode
+            )
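+
+    # Worked example for the KV bookkeeping in `finish_request` below
+    # (hypothetical numbers): a request with 10 prompt tokens and 6 output
+    # tokens, of which the final verify pass contributed
+    # finish_extend_len[rid] = 3, keeps req_len = 10 + 6 - 3 - 1 = 12 entries
+    # in the draft worker's req_to_token row; those are the entries freed here.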
+    # Prefix sharing is not supported yet.
+    def finish_request(self, reqs: Union[Req, List[Req]]):
+        if not isinstance(reqs, List):
+            reqs = [reqs]
+        for req in reqs:
+            req_len = (
+                len(req.origin_input_ids)
+                + len(req.output_ids)
+                - self.finish_extend_len[req.rid]
+                - 1
+            )
+            kv_indices = self.model_runner.req_to_token_pool.req_to_token[
+                req.req_pool_idx
+            ][:req_len]
+            self.model_runner.token_to_kv_pool.free(kv_indices)
+            self.model_runner.req_to_token_pool.free(req.req_pool_idx)
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index d670a2d35..f8a935894 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -13,6 +13,7 @@ suites = {
         "test_abort.py",
         "test_chunked_prefill.py",
         "test_double_sparsity.py",
+        "test_eagle_infer.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
         "test_get_weights_by_name.py",
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
new file mode 100644
index 000000000..609d4411d
--- /dev/null
+++ b/test/srt/test_eagle_infer.py
@@ -0,0 +1,39 @@
+import unittest
+
+import sglang as sgl
+
+
+class TestEAGLEEngine(unittest.TestCase):
+
+    def test_eagle_accuracy(self):
+        prompt = "Today is a sunny day and I like"
+        target_model_path = "meta-llama/Llama-2-7b-chat-hf"
+        speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+        engine = sgl.Engine(
+            model_path=target_model_path,
+            speculative_draft_model_path=speculative_draft_model_path,
+            speculative_algorithm="EAGLE",
+            speculative_num_steps=3,
+            speculative_eagle_topk=4,
+            speculative_num_draft_tokens=16,
+        )
+        out1 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        engine = sgl.Engine(model_path=target_model_path)
+        out2 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+
+        print("==== Answer 1 ====")
+        print(out1)
+
+        print("==== Answer 2 ====")
+        print(out2)
+        # Greedy outputs must be identical with and without speculative decoding.
+        self.assertEqual(out1, out2)
+
+
+if __name__ == "__main__":
+    unittest.main()