diff --git a/python/pyproject.toml b/python/pyproject.toml
index ebc8139f2..5820e73ed 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -43,7 +43,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.5.post1",
+    "sgl-kernel==0.0.5.post2",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "vllm>=0.6.4.post1,<=0.7.2",
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index cf95f0fb2..32a11c15c 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -283,7 +283,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)

-@app.post("/flush_cache")
+@app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
     _global_state.tokenizer_manager.flush_cache()
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 558141d74..b74dcc39d 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -895,7 +895,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
         spec_accept_length = 0
@@ -913,7 +912,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"accept len: {spec_accept_length:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 99ef14c2c..445476f07 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -117,7 +117,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-        capture_bs = list(range(1, 33))
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture fewer batch sizes.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]

     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
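Reviewer note on the hunk above: the new capture list trades graph count for memory (dense sizes 1-8, odd sizes up to 31, then a coarse 64/96/128/160 tail), and the removal of `disable_cuda_graph_padding` later in this patch means replay can pad a raw batch up to the nearest captured size. A minimal sketch of that lookup, with `capture_bs`/`padded_bs` as illustrative names (the real runner keeps the list on the class):

```python
import bisect

# Mirror of the capture list built in get_batch_sizes_to_capture().
capture_bs = sorted(set(list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]))

def padded_bs(raw_bs: int) -> int:
    """Return the smallest captured batch size >= raw_bs."""
    index = bisect.bisect_left(capture_bs, raw_bs)
    if index == len(capture_bs):
        raise ValueError(f"bs={raw_bs} exceeds the largest captured graph")
    return capture_bs[index]

assert padded_bs(10) == 11  # odd sizes 9..31 are captured in steps of 2
assert padded_bs(33) == 64  # larger batches fall back to the coarse tail
```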
- capture_bs = list( - sorted( - set( - capture_bs - + [model_runner.req_to_token_pool.size - 1] - + [model_runner.req_to_token_pool.size] - ) - ) - ) + capture_bs += [model_runner.req_to_token_pool.size - 1] + [ + model_runner.req_to_token_pool.size + ] + capture_bs = list(sorted(set(capture_bs))) capture_bs = [ bs for bs in capture_bs @@ -508,7 +505,9 @@ class CudaGraphRunner: self.raw_num_token = raw_num_token self.bs = bs - def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False): + def replay( + self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False + ) -> LogitsProcessorOutput: if not skip_attn_backend_init: self.replay_prepare(forward_batch) else: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 554c9592d..9bdcc2d3b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -285,7 +285,6 @@ class ServerArgs: if self.speculative_algorithm == "EAGLE": if self.max_running_requests is None: self.max_running_requests = 32 - self.disable_cuda_graph_padding = True self.disable_overlap_schedule = True logger.info( "Overlap scheduler is disabled because of using " diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index fba411479..b26d2c2e2 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -3,8 +3,13 @@ from typing import List import torch -from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel -from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient + +from sglang.srt.utils import is_cuda_available + +if is_cuda_available(): + from sgl_kernel import ( + build_tree_kernel_efficient as sgl_build_tree_kernel_efficient, + ) def build_tree_kernel_efficient_preprocess( @@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess( top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1) top_scores_index = top_scores.indices top_scores_index = torch.sort(top_scores_index).values - draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten() @@ -108,296 +112,6 @@ def build_tree_kernel_efficient( ) -def build_tree_kernel( - verified_id: torch.Tensor, - score_list: List[torch.Tensor], - token_list: List[torch.Tensor], - parents_list: List[torch.Tensor], - seq_lens: torch.Tensor, - seq_lens_sum: int, - topk: int, - spec_steps: int, - num_verify_tokens: int, -): - parent_list, top_scores_index, draft_tokens = ( - build_tree_kernel_efficient_preprocess( - verified_id, - score_list, - token_list, - parents_list, - num_verify_tokens, - ) - ) - - bs = seq_lens.numel() - device = seq_lens.device - - tree_mask = torch.full( - ( - seq_lens_sum * num_verify_tokens - + num_verify_tokens * num_verify_tokens * bs, - ), - True, - device=device, - ) - retrive_index = torch.full( - (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long - ) - positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long) - - sgl_build_tree_kernel( - parent_list, - top_scores_index, - seq_lens.to(torch.int32), - tree_mask, - positions, - retrive_index, - topk, - spec_steps, - num_verify_tokens, - ) - - index = retrive_index.sum(dim=-1) != -spec_steps - 2 - cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1) - retrive_cum_len = torch.zeros( - (cum_len.numel() + 1,), dtype=torch.int32, device="cuda" - ) 
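For readers of `build_tree_kernel_efficient_preprocess` above: the selection step is easier to see on toy tensors. A self-contained sketch, assuming `score_list`/`ss_token_list` have already been concatenated to shape `(bs, n_candidates)` as the real function does earlier in its body:

```python
import torch

def preprocess_sketch(verified_id, score_list, ss_token_list, num_verify_tokens):
    # Keep the (num_verify_tokens - 1) highest-scoring candidates per request,
    # then restore tree order by sorting the selected indices.
    top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
    top_scores_index = torch.sort(top_scores.indices).values
    draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
    # Prepend the last verified token so each request starts from a known-good root.
    return torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

verified_id = torch.tensor([7])
scores = torch.tensor([[0.9, 0.5, 0.8, 0.1]])  # cumulative tree scores
tokens = torch.tensor([[11, 12, 13, 14]])      # candidate token ids
print(preprocess_sketch(verified_id, scores, tokens, 3))  # tensor([ 7, 11, 13])
```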
- retrive_cum_len[1:] = cum_len - # TODO: this indexing cause a synchronization, optimize this - retrive_index = retrive_index[index] - return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens - - -def test_build_tree_kernel(): - def findp(p_i, index, parent_list): - pos = index // 10 - index_list = index.tolist() - parent_list = parent_list.tolist() - res = [p_i] - while True: - p = pos[p_i] - if p == 0: - break - token_idx = parent_list[p] - p_i = index_list.index(token_idx) - res.append(p_i) - return res - - def create_mask(seq_len, draft_token, index, parent_list, max_depth): - mask = [] - positions = [] - retrive_index = [] - for i, lens in enumerate(seq_len.tolist()): - first_mask = torch.full((lens + draft_token,), True) - first_mask[-(draft_token - 1) :] = False - positions.append(lens) - mask.append(first_mask) - seq_order = [] - first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long) - r_index = [first_index] - for j in range(draft_token - 1): - mask.append(torch.full((lens + 1,), True)) - idx = findp(j, index, parent_list) - - seq_order.append(idx) - positions.append(len(idx) + seq_len) - t = torch.full((draft_token - 1,), False) - t[idx] = True - mask.append(t) - - for i in range(1, draft_token - 1): - is_leaf = 0 - for j in range(draft_token - 1): - if i in seq_order[j]: - is_leaf += 1 - - if is_leaf == 1: - order_list = [0] + [x + 1 for x in seq_order[i][::-1]] - for _ in range(max_depth + 1 - len(seq_order[i])): - order_list.append(-1) - order = torch.Tensor(order_list).cuda().to(torch.long) - r_index.append(order) - retrive_index.append(torch.stack(r_index)) - - return ( - torch.cat(mask).cuda(), - torch.Tensor(positions).cuda().to(torch.long), - torch.stack(retrive_index), - ) - - index = ( - torch.Tensor( - [ - 0, - 1, - 2, - 3, - 10, - 11, - 12, - 13, - 20, - 21, - 22, - 30, - 110, - 130, - 150, - 160, - 210, - 211, - 212, - 213, - 214, - 215, - 216, - 217, - 218, - 219, - 220, - 230, - 310, - 311, - 312, - 313, - 314, - 315, - 316, - 317, - 320, - 321, - 322, - 330, - 360, - 380, - 390, - 410, - 411, - 412, - 413, - 414, - 415, - 416, - 417, - 418, - 419, - 420, - 421, - 422, - 423, - 430, - 431, - 440, - 441, - 460, - 470, - ] - ) - .to(torch.long) - .cuda() - ) - - parent_list = ( - torch.Tensor( - [ - -1, - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 20, - 30, - 21, - 13, - 22, - 40, - 23, - 110, - 130, - 160, - 150, - 190, - 120, - 111, - 121, - 200, - 180, - 210, - 211, - 212, - 213, - 214, - 215, - 216, - 220, - 230, - 217, - 310, - 311, - 312, - 313, - 320, - 314, - 321, - 315, - 316, - 317, - ] - ) - .to(torch.long) - .cuda() - ) - - verified_seq_len = torch.Tensor([47]).to(torch.long).cuda() - bs = verified_seq_len.shape[0] - topk = 10 - depth = 5 # depth <= 10 - num_draft_token = 64 - - tree_mask = torch.full( - ( - torch.sum(verified_seq_len).item() * num_draft_token - + num_draft_token * num_draft_token * bs, - ), - True, - ).cuda() - retrive_index = torch.full( - (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long - ) - positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long) - - sgl_build_tree_kernel( - parent_list.unsqueeze(0), - index.unsqueeze(0), - verified_seq_len, - tree_mask, - positions, - retrive_index, - topk, - depth, - num_draft_token, - ) - - retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2] - - c_mask, c_positions, c_retive_index = create_mask( - verified_seq_len, num_draft_token, index, parent_list, depth - ) - - assert 
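The deleted test's `findp` helper above recovers a candidate's ancestor chain from the flat encoding in which `index[i] // topk` addresses the parent entry in `parents_list`. A pure-Python sketch of the same walk, with made-up toy values:

```python
def ancestor_chain(i, index, parent_list, topk=10):
    """Return position i plus the positions of all its draft-tree ancestors."""
    chain = [i]
    while True:
        parent_row = index[i] // topk       # row of i's parent in parent_list
        if parent_row == 0:
            break                           # parent is the verified root token
        token_idx = parent_list[parent_row]
        i = index.index(token_idx)          # position of the parent candidate
        chain.append(i)
    return chain

# Toy tree: candidate 2 hangs off candidate 0, which hangs off the root.
assert ancestor_chain(2, [0, 1, 10, 11], [-1, 0, 1, 2]) == [2, 0]
```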
torch.allclose(tree_mask, c_mask), "tree mask has error." - assert torch.allclose(positions, c_positions), "positions has error." - assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error." - - def test_build_tree_kernel_efficient(): verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32) score_list = [ @@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient(): depth = 4 num_draft_token = 8 - tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = ( - build_tree_kernel( - verified_id=verified_id, - score_list=score_list, - token_list=token_list, - parents_list=parents_list, - seq_lens=seq_lens, - seq_lens_sum=torch.sum(seq_lens).item(), - topk=topk, - spec_steps=depth, - num_verify_tokens=num_draft_token, - ) - ) - - from sglang.srt.utils import first_rank_print - - first_rank_print("=========== build tree kernel ==========") - # first_rank_print(f"{tree_mask=}", flush=True) - first_rank_print(f"{position=}", flush=True) - first_rank_print(f"{retrive_index=}", flush=True) - first_rank_print(f"{retrive_cum_len=}", flush=True) - first_rank_print(f"{draft_tokens=}", flush=True) - assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14] - assert retrive_index.tolist() == [ - [0, -1, -1, -1, -1, -1], - [0, 2, 4, 6, -1, -1], - [0, 1, 3, 5, 7, -1], - [8, -1, -1, -1, -1, -1], - [8, 9, 10, -1, -1, -1], - [8, 9, 12, -1, -1, -1], - [8, 9, 13, -1, -1, -1], - [8, 9, 11, 14, 15, -1], - ] - assert retrive_cum_len.tolist() == [0, 3, 8] - assert draft_tokens.tolist() == [ - 29974, - 29896, - 29906, - 29889, - 29974, - 29946, - 29896, - 29946, - 13, - 13, - 22550, - 4136, - 16492, - 8439, - 29871, - 29941, - ] - ( tree_mask, position, @@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient(): if __name__ == "__main__": test_build_tree_kernel_efficient() - test_build_tree_kernel() diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index e5410ec00..88ee3a486 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -22,6 +22,10 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput if TYPE_CHECKING: from sglang.srt.speculative.eagle_worker import EAGLEWorker +import logging + +logger = logging.getLogger(__name__) + class EAGLEDraftCudaGraphRunner: def __init__(self, eagle_worker: EAGLEWorker): @@ -33,13 +37,10 @@ class EAGLEDraftCudaGraphRunner: self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.disable_padding = model_runner.server_args.disable_cuda_graph_padding self.tp_size = self.model_runner.tp_size - self.dp_size = model_runner.server_args.dp_size self.topk = model_runner.server_args.speculative_eagle_topk self.speculative_num_steps = model_runner.server_args.speculative_num_steps server_args = model_runner.server_args - assert self.disable_padding - # Batch sizes to capture self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.num_tokens_per_bs = server_args.speculative_eagle_topk @@ -169,6 +170,13 @@ class EAGLEDraftCudaGraphRunner: set_global_graph_memory_pool(graph.pool()) return graph, out + def _postprocess_output_to_raw_bs(self, out, raw_bs): + score_list, token_list, parents_list = out + score_list = [x[:raw_bs] for x in score_list] + token_list = [x[:raw_bs] for x in token_list] + parents_list = [x[:raw_bs] for x in parents_list] + return (score_list, token_list, 
parents_list) + def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None raw_bs = forward_batch.batch_size @@ -180,6 +188,9 @@ class EAGLEDraftCudaGraphRunner: if bs != raw_bs: self.seq_lens.fill_(1) self.out_cache_loc.zero_() + self.positions.zero_() + + num_tokens = bs * self.num_tokens_per_bs # Common inputs self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) @@ -193,11 +204,25 @@ class EAGLEDraftCudaGraphRunner: self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) # Attention backend + if bs != raw_bs: + forward_batch.batch_size = bs + forward_batch.seq_lens = self.seq_lens[:bs] + forward_batch.req_pool_indices = self.req_pool_indices[:bs] + forward_batch.positions = self.positions[:num_tokens] + self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph( - forward_batch, forward_batch.batch_size + forward_batch, bs ) # Replay self.graphs[bs].replay() + out = self.output_buffers[bs] - return self.output_buffers[bs] + if bs != raw_bs: + out = self._postprocess_output_to_raw_bs(out, raw_bs) + forward_batch.batch_size = raw_bs + forward_batch.positions = self.positions[:raw_num_token] + forward_batch.seq_lens = self.seq_lens[:raw_bs] + forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs] + + return out diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index fa8dc5f21..3dc2a9699 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional import torch import torch.nn.functional as F @@ -13,18 +13,24 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode -from sglang.srt.speculative.build_eagle_tree import ( - build_tree_kernel, - build_tree_kernel_efficient, -) +from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.utils import is_cuda_available if is_cuda_available(): - from sgl_kernel import tree_speculative_sampling_target_only + from sgl_kernel import ( + top_k_renorm_prob, + top_p_renorm_prob, + tree_speculative_sampling_target_only, + verify_tree_greedy, + ) if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch +import logging + +logger = logging.getLogger(__name__) + @dataclass class EagleDraftInput: @@ -47,12 +53,9 @@ class EagleDraftInput: kv_indptr: torch.Tensor = None kv_indices: torch.Tensor = None - # indices of unfinished requests during extend-after-decode - # e.g. [0, 2, 3, 4] if only the 1st request is finished - keep_indices: List[int] = None + all_padding_lens: Optional[torch.Tensor] = None def prepare_for_extend(self, batch: ScheduleBatch): - assert batch.input_ids.numel() == batch.out_cache_loc.shape[0] # Prefill only generate 1 token. 
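Reviewer note on the new padded replay path above: since `disable_cuda_graph_padding` is no longer forced for EAGLE, a raw batch smaller than the captured size runs through the captured graph and the outputs are sliced back, which is exactly what `_postprocess_output_to_raw_bs` does. A dummy-buffer sketch of the pattern (all names here are illustrative):

```python
import torch

bs, raw_bs = 8, 5                       # captured graph size vs. actual batch
seq_lens = torch.zeros(bs, dtype=torch.int64)
new_seq_lens = torch.tensor([3, 9, 4, 2, 7])

seq_lens.fill_(1)                       # padding rows get a harmless length
seq_lens[:raw_bs].copy_(new_seq_lens)   # real rows overwrite the front
out = seq_lens * 2                      # stand-in for graphs[bs].replay()
out = out[:raw_bs]                      # strip the padded tail before returning
```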
assert len(self.verified_id) == len(batch.seq_lens) @@ -64,27 +67,18 @@ class EagleDraftInput: ) pt += extend_len - def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps): - assert self.verified_id.numel() == batch.out_cache_loc.shape[0] + def prepare_extend_after_decode( + self, + batch: ScheduleBatch, + speculative_num_steps: int, + ): + assert len(self.verified_id) == len(batch.out_cache_loc) accept_length_cpu = batch.spec_info.accept_length_cpu batch.extend_lens = [x + 1 for x in accept_length_cpu] batch.extend_num_tokens = sum(batch.extend_lens) batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend seq_lens_cpu = batch.seq_lens.tolist() - assert len(batch.req_pool_indices) == len(batch.reqs) - - pt = 0 - i = 0 - self.keep_indices = [] - for idx, req in enumerate(batch.reqs): - if req.finished(): - continue - self.keep_indices.append(idx) - # assert seq_len - pre_len == req.extend_input_len - input_len = batch.extend_lens[i] - seq_len = seq_lens_cpu[i] - pt += input_len - i += 1 self.positions = torch.empty_like(self.verified_id, dtype=torch.long) new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32) @@ -112,10 +106,6 @@ class EagleDraftInput: req_to_token: torch.Tensor, ): bs = self.accept_length.numel() - keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device) - req_pool_indices = req_pool_indices[keep_indices] - assert req_pool_indices.shape[0] == bs - assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0] qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) @@ -172,7 +162,7 @@ class EagleVerifyOutput: # Accepeted token length per sequence in a batch in CPU. 
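With the per-request `keep_indices` bookkeeping removed, `prepare_extend_after_decode` and `generate_attn_arg_prefill` derive their layout purely from `accept_length`: the `qo_indptr` handed to the attention backend is an exclusive prefix sum over per-request token counts, in FlashInfer's CSR-style convention. A toy sketch:

```python
import torch

accept_length = torch.tensor([2, 1, 3], dtype=torch.int32)  # toy values
bs = accept_length.numel()
qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(accept_length, dim=0)
print(qo_indptr)  # tensor([0, 2, 3, 6], dtype=torch.int32)
```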
accept_length_per_req_cpu: List[int] # Accepeted indices from logits_output.next_token_logits - accepeted_indices_cpu: List[int] + accepeted_indices: torch.Tensor @dataclass @@ -200,67 +190,38 @@ class EagleVerifyInput: topk: int, spec_steps: int, num_verify_tokens: int, - is_all_greedy: bool, ): - if is_all_greedy: - tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = ( - build_tree_kernel( - verified_id, - score_list, # b, n, topk; n= 1 + (num_steps-1) * self.topk - token_list, - parents_list, - seq_lens, - seq_lens_sum, - topk, - spec_steps, - num_verify_tokens, - ) - ) + ( + tree_mask, + position, + retrive_index, + retrive_next_token, + retrive_next_sibling, + draft_tokens, + ) = build_tree_kernel_efficient( + verified_id, + score_list, + token_list, + parents_list, + seq_lens, + seq_lens_sum, + topk, + spec_steps, + num_verify_tokens, + ) - return cls( - draft_tokens, - tree_mask, - position, - retrive_index, - None, - None, - retrive_cum_len, - num_verify_tokens, - spec_steps, - CaptureHiddenMode.FULL, - ) - else: - ( - tree_mask, - position, - retrive_index, - retrive_next_token, - retrive_next_sibling, - draft_tokens, - ) = build_tree_kernel_efficient( - verified_id, - score_list, - token_list, - parents_list, - seq_lens, - seq_lens_sum, - topk, - spec_steps, - num_verify_tokens, - ) - - return cls( - draft_tokens, - tree_mask, - position, - retrive_index, - retrive_next_token, - retrive_next_sibling, - None, - num_verify_tokens, - spec_steps, - CaptureHiddenMode.FULL, - ) + return cls( + draft_tokens, + tree_mask, + position, + retrive_index, + retrive_next_token, + retrive_next_sibling, + None, + num_verify_tokens, + spec_steps, + CaptureHiddenMode.FULL, + ) def prepare_for_verify(self, batch: ScheduleBatch): batch.input_ids = self.draft_token @@ -291,7 +252,6 @@ class EagleVerifyInput: dtype=torch.int32, device="cuda", ) - cum_kv_seq_len = torch.zeros( (batch_size + 1,), dtype=torch.int32, device="cuda" ) @@ -304,7 +264,6 @@ class EagleVerifyInput: dtype=torch.int32, device="cuda", ) - create_flashinfer_kv_indices_triton[(batch_size,)]( req_to_token, req_pool_indices, @@ -322,65 +281,79 @@ class EagleVerifyInput: logits_output: torch.Tensor, token_to_kv_pool_allocator: TokenToKVPoolAllocator, ) -> torch.Tensor: - """WARNING: This API in-place modifies the states of logits_output - + """ Verify and find accepted tokens based on logits output and batch (which contains spec decoding information). + WARNING: This API in-place modifies the states of logits_output + This API updates values inside logits_output based on the accepted tokens. I.e., logits_output.next_token_logits only contains accepeted token logits. 
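`verify_tree_greedy`, used below for the temperature-0 path, is a fused sgl-kernel op. The following is a pure-Python reference of the traversal it performs on one request's tree, written against the `retrive_next_token`/`retrive_next_sibling` linked-list encoding produced by `build_tree_kernel_efficient`; it is a sketch of the semantics, not the kernel API. A child is accepted when its draft token equals the target model's greedy prediction at its parent.

```python
def greedy_verify(candidates, target_predict, next_token, next_sibling):
    """Reference walk over one request's draft tree (node 0 is the root)."""
    accepted = [0]
    cur = next_token[0]                     # first child of the root
    while cur != -1:
        if candidates[cur] == target_predict[accepted[-1]]:
            accepted.append(cur)            # match: descend into this child
            cur = next_token[cur]
        else:
            cur = next_sibling[cur]         # mismatch: try the next sibling
    return accepted

# Root predicts 9; node 1 (token 9) matches and predicts 7; node 3 matches it.
assert greedy_verify(
    candidates=[5, 9, 4, 7],
    target_predict=[9, 7, 0, 2],
    next_token=[1, 3, -1, -1],
    next_sibling=[-1, 2, -1, -1],
) == [0, 1, 3]
```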
""" - draft_token = torch.cat( - [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")], - dim=-1, - ) - candidates = draft_token[self.retrive_index] - if batch.sampling_info.is_all_greedy: - # temp == 0 - bs = self.retrive_cum_len.numel() - 1 - predict = torch.argmax(logits_output.next_token_logits, dim=-1) - predict = torch.cat( - [predict, torch.full([1], -1, dtype=torch.int32, device="cuda")], dim=-1 - ) - target_predict = predict[self.retrive_index] - # logits = logits_output.next_token_logits[self.retrive_index] - # target_predict = torch.argmax(logits[:, :-1], dim=-1) - accept_mask = candidates[:, 1:] == target_predict[:, :-1] + bs = self.retrive_index.shape[0] + candidates = self.draft_token.reshape(bs, self.draft_token_num) + sampling_info = batch.sampling_info - accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) - max_draft_len = self.retrive_index.shape[-1] - accept_index = torch.full( - (bs, max_draft_len), -1, dtype=torch.int32, device="cuda" + predict_shape = list(logits_output.next_token_logits.shape)[:-1] + predict_shape[-1] += 1 + predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda") + accept_index = torch.full( + (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") + + if sampling_info.penalizer_orchestrator.is_required: + # This is a relaxed version of penalties for speculative decoding. + linear_penalty = torch.zeros( + (bs, logits_output.next_token_logits.shape[1]), + dtype=torch.float32, + device="cuda", ) - accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") - extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") - eagle_verify_retrive[(bs,)]( - self.retrive_index.contiguous(), - accept_mask.contiguous(), - self.retrive_cum_len, - accept_index, - accept_length, - extract_index, - max_draft_len, - self.draft_token_num, - triton.next_power_of_2(max_draft_len), + sampling_info.apply_logits_bias(linear_penalty) + logits_output.next_token_logits.add_( + torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) + ) + + if batch.sampling_info.is_all_greedy: + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1) + target_predict = target_predict.reshape(bs, self.draft_token_num) + + verify_tree_greedy( + predicts=predict, # mutable + accept_index=accept_index, # mutable + accept_token_num=accept_length, # mutable + candidates=candidates.to(torch.int32), + retrive_index=self.retrive_index.to(torch.int32), + retrive_next_token=self.retrive_next_token.to(torch.int32), + retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), + target_predict=target_predict.to(torch.int32), ) else: - # temp > 0 - bs = self.retrive_index.shape[0] - predict_shape = list(logits_output.next_token_logits.shape)[:-1] - predict_shape[-1] += 1 - target_logits = logits_output.next_token_logits[self.retrive_index] - predict = torch.full(predict_shape, -1, dtype=torch.int32, device="cuda") - accept_index = torch.full( - (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda" + # apply temperature and get target probs + expanded_temperature = torch.repeat_interleave( + sampling_info.temperatures, self.draft_token_num, dim=0 + ) # (bs * draft_token_num, 1) + + target_probs = F.softmax( + logits_output.next_token_logits / expanded_temperature, dim=-1 + ) # (bs * draft_token_num, vocab_size) + target_probs = top_k_renorm_prob( + target_probs, + torch.repeat_interleave( + sampling_info.top_ks, 
self.draft_token_num, dim=0 + ), + ) # (bs * draft_token_num, vocab_size) + target_probs = top_p_renorm_prob( + target_probs, + torch.repeat_interleave( + sampling_info.top_ps, self.draft_token_num, dim=0 + ), ) - accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") - expanded_temperature = batch.sampling_info.temperatures.unsqueeze(1) - target_probs = F.softmax(target_logits / expanded_temperature, dim=-1) - draft_probs = torch.full_like( - target_probs, 0, dtype=torch.float32, device="cuda" + target_probs = target_probs.reshape(bs, self.draft_token_num, -1) + + draft_probs = torch.zeros( + target_probs.shape, dtype=torch.float32, device="cuda" ) coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda") tree_speculative_sampling_target_only( @@ -394,6 +367,12 @@ class EagleVerifyInput: uniform_samples=coins, target_probs=target_probs, draft_probs=draft_probs, + threshold_single=global_server_args_dict[ + "speculative_accept_threshold_single" + ], + threshold_acc=global_server_args_dict[ + "speculative_accept_threshold_acc" + ], deterministic=True, ) @@ -425,119 +404,94 @@ class EagleVerifyInput: new_accept_index.extend(new_accept_index_) unfinished_index.append(i) req.spec_verify_ct += 1 - accept_length = (accept_index != -1).sum(dim=1) - 1 - accept_index = accept_index[accept_index != -1] - accept_length_cpu = accept_length.tolist() - verified_id = predict[accept_index] - evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) - evict_mask[accept_index] = False - mem_need_free_idx = batch.out_cache_loc[evict_mask] - token_to_kv_pool_allocator.free(mem_need_free_idx) - assign_req_to_token_pool[(bs,)]( - batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens, - batch.seq_lens + accept_length + 1, - batch.out_cache_loc[accept_index], - batch.req_to_token_pool.req_to_token.shape[1], - triton.next_power_of_2(bs), - ) - batch.seq_lens.add_(accept_length + 1) + if not has_finished: + accept_index = accept_index[accept_index != -1] + verified_id = predict[accept_index] + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index] = False + mem_need_free_idx = batch.out_cache_loc[evict_mask] + token_to_kv_pool_allocator.free(mem_need_free_idx) + batch.out_cache_loc = batch.out_cache_loc[accept_index] + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length + 1, + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + batch.seq_lens.add_(accept_length + 1) + accept_length_cpu = accept_length.tolist() - draft_input = EagleDraftInput() - if len(new_accept_index) > 0: - new_accept_index = torch.tensor(new_accept_index, device="cuda") - draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] - draft_input.verified_id = predict[new_accept_index] - draft_input.accept_length = accept_length[unfinished_index] - draft_input.accept_length_cpu = [ - accept_length_cpu[i] for i in unfinished_index - ] - if has_finished: - draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] - else: - draft_input.seq_lens_for_draft_extend = batch.seq_lens - batch.out_cache_loc = batch.out_cache_loc[new_accept_index] + draft_input = EagleDraftInput() + draft_input.hidden_states = batch.spec_info.hidden_states[accept_index] + draft_input.verified_id = verified_id + draft_input.accept_length = accept_length + draft_input.accept_length_cpu = 
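Both verify branches in this hunk share the same KV-cache reclamation idiom: every draft slot starts marked for eviction, accepted slots are unmarked, and the allocator frees exactly the rejected tokens. A CPU-only toy sketch of that mask:

```python
import torch

out_cache_loc = torch.tensor([100, 101, 102, 103, 104, 105])  # one per draft token
accept_index = torch.tensor([0, 2, 3])                        # accepted flat indices

evict_mask = torch.full((6,), True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = out_cache_loc[evict_mask]
print(mem_need_free_idx)  # tensor([101, 104, 105]) -> returned to the KV pool
```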
accept_length_cpu + draft_input.seq_lens_for_draft_extend = batch.seq_lens + draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices - return EagleVerifyOutput( - draft_input=draft_input, - logits_output=logits_output, - verified_id=verified_id, - accept_length_per_req_cpu=accept_length_cpu, - accepeted_indices_cpu=accept_index, - ) + return EagleVerifyOutput( + draft_input=draft_input, + logits_output=logits_output, + verified_id=verified_id, + accept_length_per_req_cpu=accept_length_cpu, + accepeted_indices=accept_index, + ) + else: + accept_length = (accept_index != -1).sum(dim=1) - 1 + accept_index = accept_index[accept_index != -1] + verified_id = predict[accept_index] + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index] = False + mem_need_free_idx = batch.out_cache_loc[evict_mask] + token_to_kv_pool_allocator.free(mem_need_free_idx) + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length + 1, + batch.out_cache_loc[accept_index], + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + batch.seq_lens.add_(accept_length + 1) + accept_length_cpu = accept_length.tolist() + draft_input = EagleDraftInput() + if len(new_accept_index) > 0: + new_accept_index = torch.tensor(new_accept_index, device="cuda") + draft_input.hidden_states = batch.spec_info.hidden_states[ + new_accept_index + ] + draft_input.verified_id = predict[new_accept_index] + draft_input.accept_length = accept_length[unfinished_index] + draft_input.accept_length_cpu = [ + accept_length_cpu[i] for i in unfinished_index + ] + if has_finished: + draft_input.seq_lens_for_draft_extend = batch.seq_lens[ + unfinished_index + ] + draft_input.req_pool_indices_for_draft_extend = ( + batch.req_pool_indices[unfinished_index] + ) + else: + draft_input.seq_lens_for_draft_extend = batch.seq_lens + draft_input.req_pool_indices_for_draft_extend = ( + batch.req_pool_indices + ) + batch.out_cache_loc = batch.out_cache_loc[new_accept_index] -@triton.jit -def eagle_verify_retrive( - retrive_index, - accept_mask, - retrive_cum_len, - accept_index, - accept_length, - extract_index, - max_len: tl.constexpr, - draft_token_num: tl.constexpr, - max_len_upper: tl.constexpr, -): - """ - Args: - retrive_index: Pointer to indices of draft tokens - accept_mask: Mask indicating which tokens were accepted - retrive_cum_len: Cumulative lengths of token sequences in a batch - accept_index (out): Accept token indices - accept_length (out): Length of accepted tokens per sequence in a batch - extract_index (out): Index for last accepted tokens - max_len: Maximum length in a batch - draft_token_num: Number of tokens speculatively generated - max_len_upper An upper bound for token sequence length - """ - pid = tl.program_id(axis=0) - - retrive_end = tl.load(retrive_cum_len + pid + 1) - retrive_start = tl.load(retrive_cum_len + pid) - retrive_len = retrive_end - retrive_start - accept_ptr = accept_mask + retrive_start - accept_offset = tl.arange(0, draft_token_num) - accept_load_mask = accept_offset < retrive_len - accept_len_list = tl.load( - accept_ptr + accept_offset, mask=accept_load_mask, other=-1 - ) - - accept_len = tl.max(accept_len_list) - max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True) - # triton is not support argmax with tie_break_right, so I need implement it by some way - mask_max = accept_len_list == accept_len - - count_mask = 
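In the deleted `eagle_verify_retrive` kernel continuing below, the `count_mask`/`remained_index` dance works around Triton's `tl.argmax` supporting only left tie-breaking, while the kernel wants the rightmost position attaining the maximum. The plain-Python equivalent of what it computed:

```python
def argmax_tie_break_right(xs):
    """Index of the maximum, preferring the rightmost one on ties."""
    best = max(xs)
    return max(i for i, x in enumerate(xs) if x == best)

assert argmax_tie_break_right([3, 5, 5, 2]) == 2  # left tie-breaking would give 1
```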
tl.full(shape=[draft_token_num], value=0, dtype=tl.int32) - count = tl.sum(tl.where(mask_max, 1, count_mask)) - if count > 1: - index = tl.arange(0, draft_token_num) - mask_left = index != max_index - remained_index = tl.where(mask_max and mask_left, index, 0) - max_index = tl.max(remained_index) - - tl.store(accept_length + pid, accept_len) - retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len - retrive_offset = tl.arange(0, max_len_upper) - retrive_load_mask = retrive_offset < accept_len + 1 - data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) - - tl.store( - accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask - ) - - extract_load_ptr = accept_index + pid * max_len + accept_len - if accept_len == max_len - 1: - extract_data = tl.load(extract_load_ptr - 1) - tl.store(extract_index + pid * 2, extract_data) - extract_data = tl.load(extract_load_ptr) - tl.store(extract_index + pid * 2 + 1, extract_data) - - else: - extract_data = tl.load(extract_load_ptr) - tl.store(extract_index + pid * 2, extract_data) + return EagleVerifyOutput( + draft_input=draft_input, + logits_output=logits_output, + verified_id=verified_id, + accept_length_per_req_cpu=accept_length_cpu, + accepeted_indices=accept_index, + ) @triton.jit diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 90d47cc0f..e2dee9e12 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1,11 +1,14 @@ import logging import os import time +from contextlib import contextmanager from typing import List, Optional, Tuple import torch from huggingface_hub import snapshot_download +from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group +from sglang.srt.layers.dp_attention import disable_dp_size from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -27,11 +30,22 @@ from sglang.srt.speculative.eagle_utils import ( fast_topk, select_top_k_tokens, ) -from sglang.srt.utils import get_available_gpu_memory +from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available + +if is_cuda_available(): + from sgl_kernel import segment_packbits logger = logging.getLogger(__name__) +@contextmanager +def draft_tp_context(tp_group: GroupCoordinator): + # Draft model doesn't use dp and has its own tp group. + # We disable mscclpp now because it doesn't support 2 comm groups. 
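Sketch of the two context-manager pieces being introduced here: `draft_tp_context` swaps the process-global TP group for the draft model and restores it on exit, and `empty_context` (added to `utils.py` further down) is a no-op stand-in so callers can select either one uniformly. The names mirror the patch, but the bodies below are illustrative only, not the real `sglang.srt.distributed` implementation:

```python
from contextlib import contextmanager

_TP_GROUP = "target-tp-group"          # stand-in for the global coordinator

@contextmanager
def patch_tp_group(group):
    global _TP_GROUP
    old, _TP_GROUP = _TP_GROUP, group  # swap in the draft model's TP group
    try:
        yield
    finally:
        _TP_GROUP = old                # always restore, even on error

@contextmanager
def empty_context(*args, **kwargs):
    yield                              # no-op, like contextlib.nullcontext

# Callers pick one object and use it unconditionally, as EAGLEWorker does:
enable_dp_attention = True
ctx = patch_tp_group if enable_dp_attention else empty_context
with ctx("draft-tp-group"):
    assert _TP_GROUP == "draft-tp-group"
assert _TP_GROUP == "target-tp-group"
```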
+ with disable_dp_size(), patch_tensor_parallel_group(tp_group): + yield + + class EAGLEWorker(TpModelWorker): def __init__( @@ -76,16 +90,17 @@ class EAGLEWorker(TpModelWorker): self.hot_token_id = None # Init draft worker - super().__init__( - gpu_id=gpu_id, - tp_rank=tp_rank, - server_args=server_args, - nccl_port=nccl_port, - dp_rank=dp_rank, - is_draft_worker=True, - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - ) + with empty_context(): + super().__init__( + gpu_id=gpu_id, + tp_rank=tp_rank, + server_args=server_args, + nccl_port=nccl_port, + dp_rank=dp_rank, + is_draft_worker=True, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + ) # Share the embedding and lm_head embed, head = self.target_worker.model_runner.model.get_embed_and_head() @@ -94,12 +109,17 @@ class EAGLEWorker(TpModelWorker): self.hot_token_id = self.hot_token_id.to(head.device) head.data = head.data[self.hot_token_id] self.draft_model_runner.model.set_embed_and_head(embed, head) + + # Init attention backend and cuda graphs self.draft_model_runner.server_args.disable_cuda_graph = ( backup_disable_cuda_graph ) - - self.init_attention_backend() - self.init_cuda_graphs() + self.draft_tp_context = ( + draft_tp_context if server_args.enable_dp_attention else empty_context + ) + with self.draft_tp_context(self.draft_model_runner.tp_group): + self.init_attention_backend() + self.init_cuda_graphs() def init_attention_backend(self): # Create multi-step attn backends and cuda graph runners @@ -109,52 +129,70 @@ class EAGLEWorker(TpModelWorker): ) self.draft_attn_backend = FlashInferMultiStepDraftBackend( - self.model_runner, + self.draft_model_runner, self.topk, self.speculative_num_steps, ) + self.draft_extend_attn_backend = None + self.padded_static_len = self.speculative_num_steps + 1 + self.has_prefill_wrapper_verify = True elif self.server_args.attention_backend == "triton": from sglang.srt.layers.attention.triton_backend import ( TritonMultiStepDraftBackend, ) self.draft_attn_backend = TritonMultiStepDraftBackend( - self.model_runner, + self.draft_model_runner, self.topk, self.speculative_num_steps, ) + self.draft_extend_attn_backend = None + self.padded_static_len = self.speculative_num_steps + 1 + self.has_prefill_wrapper_verify = False elif self.server_args.attention_backend == "flashinfer_mla": from sglang.srt.layers.attention.flashinfer_mla_backend import ( FlashInferMLAMultiStepDraftBackend, ) self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend( - self.model_runner, + self.draft_model_runner, self.topk, self.speculative_num_steps, ) + self.draft_extend_attn_backend = None + self.padded_static_len = self.speculative_num_steps + 1 + self.has_prefill_wrapper_verify = True else: raise ValueError( f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}" ) + self.draft_model_runner.draft_attn_backend = self.draft_attn_backend def init_cuda_graphs(self): """Capture cuda graphs.""" self.cuda_graph_runner = None + self.cuda_graph_runner_for_draft_extend = None if self.server_args.disable_cuda_graph: return + # Capture draft tic = time.time() + before_mem = get_available_gpu_memory(self.device, self.gpu_id) logger.info( - f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + f"Capture draft cuda graph begin. This can take up to several minutes. 
avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
        )

+        # Capture extend
+        if self.draft_extend_attn_backend:
+            raise NotImplementedError()
+
     @property
     def draft_model_runner(self):
         return self.model_runner
@@ -164,8 +202,8 @@ class EAGLEWorker(TpModelWorker):
     ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
         """Run speculative decoding forward.

-        NOTE: Many states of batch is modified as you go through. It is not guaranteed
-        the final output batch doesn't have the same state as the input.
+        NOTE: Many states of the batch are modified as you go through. It is not guaranteed
+        that the final output batch has the same state as the input.

         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
@@ -173,30 +211,42 @@
             A tuple of the final logit output of the target model, next tokens accepeted,
             the batch id (used for overlap schedule), and number of accepeted tokens.
         """
-        assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-            spec_info, to_free_cache_loc = self.draft(batch)
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                spec_info, to_free_cache_loc = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )

+            # Free the cache locations here to avoid a synchronization and to hide kernel launch overhead.
             self.token_to_kv_pool_allocator.free(to_free_cache_loc)

-            # if it is None, means all requests are finished
-            if batch.spec_info.verified_id is not None:
-                self.forward_draft_extend_after_decode(batch)
+            # If it is None, it means all requests are finished
+            if batch.spec_info.verified_id is not None:
+                with self.draft_tp_context(self.draft_model_runner.tp_group):
+                    self.forward_draft_extend_after_decode(batch)

             return (
                 logits_output,
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
             )
-
+        elif batch.forward_mode.is_idle():
+            model_worker_batch = batch.get_model_worker_batch()
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(
+                    ForwardBatch.init_new(
+                        model_worker_batch, self.target_worker.model_runner
+                    )
+                )
+            )
+            return logits_output, next_token_ids, model_worker_batch.bid, 0
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
-            self.forward_draft_extend(
-                batch, logits_output.hidden_states, next_token_ids
-            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids
+                )
             return logits_output, next_token_ids, bid, 0

     def forward_target_extend(
@@ -226,6 +276,13 @@
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info

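The `# Accumulate penalty` block just below feeds the orchestrator with the tokens accepted in the previous round; on the verify side (see `eagle_utils.py` above), the accumulated `(bs, vocab)` penalty row is then broadcast to the `(bs * draft_token_num, vocab)` logits with `repeat_interleave`, since each request owns `draft_token_num` consecutive rows. A toy sketch of that alignment:

```python
import torch

bs, draft_token_num, vocab = 2, 3, 5
logits = torch.zeros(bs * draft_token_num, vocab)
penalty = torch.tensor([[0., -1., 0., 0., 0.],   # one penalty row per request
                        [0., 0., 0., -2., 0.]])
# Rows 0-2 belong to request 0, rows 3-5 to request 1.
logits.add_(torch.repeat_interleave(penalty, draft_token_num, dim=0))
assert logits[0, 1] == -1 and logits[3, 3] == -2
```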
+        # Accumulate penalty
+        if batch.sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                spec_info.verified_id.to(torch.int64)
+            )
+
         # Allocate cache locations
         out_cache_loc = batch.alloc_token_slots(
             num_seqs * self.topk * self.speculative_num_steps
@@ -275,9 +332,7 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
-            batch.sampling_info.is_all_greedy,
         )
-
         return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
@@ -307,7 +362,7 @@
             token_list.append(tree_info[1])
             parents_list.append(tree_info[2])

-            # we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
+            # We don't need to run the last forward. We get 1 token from the draft prefill and (#spec steps - 1) tokens here.
             if i == self.speculative_num_steps - 1:
                 break

@@ -322,7 +377,7 @@
         spec_info.hidden_states = hidden_states

         # Run forward
-        logits_output = self.model_runner.model.forward(
+        logits_output = self.draft_model_runner.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
         self._detect_nan_if_needed(logits_output)
@@ -351,11 +406,10 @@
         # Post process based on verified outputs.
         # Pick indices that we care (accepeted)
         logits_output.next_token_logits = logits_output.next_token_logits[
-            res.accepeted_indices_cpu
-        ]
-        logits_output.hidden_states = logits_output.hidden_states[
-            res.accepeted_indices_cpu
+            res.accepeted_indices
         ]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
+
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
@@ -407,7 +461,7 @@
             batch_next_token_ids,
         ]

-        # Add output logprobs to the request.
+        # Add output logprobs to the request
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
@@ -456,27 +510,38 @@
         self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        seq_lens_backup = batch.seq_lens
+        # Backup fields that will be modified in-place
+        seq_lens_backup = batch.seq_lens.clone()
+        req_pool_indices_backup = batch.req_pool_indices
+        accept_length_backup = batch.spec_info.accept_length
+        return_logprob_backup = batch.return_logprob
+
+        # Prepare metadata
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        # We don't need logprob for this extend.
-        original_return_logprob = batch.return_logprob
         batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
+
+        # Run
         logits_output = self.draft_model_runner.forward(forward_batch)
+
         self._detect_nan_if_needed(logits_output)
-        assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode` - batch.return_logprob = original_return_logprob batch.forward_mode = ForwardMode.DECODE batch.seq_lens = seq_lens_backup + batch.req_pool_indices = req_pool_indices_backup + batch.spec_info.accept_length = accept_length_backup + batch.return_logprob = return_logprob_backup def capture_for_decode( self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput @@ -489,7 +554,7 @@ class EAGLEWorker(TpModelWorker): if self.enable_nan_detection: logits = logits_output.next_token_logits if torch.any(torch.isnan(logits)): - logger.warning("Detected errors during sampling! NaN in the logits.") + logger.error("Detected errors during sampling! NaN in the logits.") raise ValueError("Detected errors during sampling! NaN in the logits.") diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index f70c7e9ec..e1f009b1e 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -36,6 +36,7 @@ import tempfile import threading import time import warnings +from contextlib import contextmanager from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from importlib.util import find_spec @@ -1577,6 +1578,16 @@ def next_power_of_2(n: int): setattr(triton, "next_power_of_2", next_power_of_2) +@contextmanager +def empty_context(*args, **kwargs): + try: + # Setup code goes here + yield + finally: + # Cleanup code goes here + pass + + def add_prefix(name: str, prefix: str) -> str: """Add a weight path prefix to a module name. diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index b1d14973e..fba6bdd80 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -24,6 +24,3 @@ pip install transformers==4.48.3 sentence_transformers accelerate==1.4.0 peft pa # For compling xgrammar kernels pip install cuda-python nvidia-cuda-nvrtc-cu12 - -# reinstall sgl-kernel -pip install sgl-kernel==0.0.5.post1 --force-reinstall --no-deps diff --git a/sgl-kernel/csrc/speculative/speculative_sampling.cuh b/sgl-kernel/csrc/speculative/speculative_sampling.cuh index 10b14713f..158dfa471 100644 --- a/sgl-kernel/csrc/speculative/speculative_sampling.cuh +++ b/sgl-kernel/csrc/speculative/speculative_sampling.cuh @@ -36,8 +36,8 @@ template < typename DType, typename IdType> __global__ void TreeSpeculativeSamplingTargetOnly( - IdType* predicts, - IdType* accept_index, + IdType* predicts, // mutable + IdType* accept_index, // mutable IdType* accept_token_num, // mutable IdType* candidates, IdType* retrive_index, @@ -158,8 +158,8 @@ __global__ void TreeSpeculativeSamplingTargetOnly( template cudaError_t TreeSpeculativeSamplingTargetOnly( - IdType* predicts, - IdType* output_token_ids, + IdType* predicts, // mutable + IdType* output_token_ids, // mutable IdType* output_accepted_token_num, // mutable IdType* candidates, IdType* retrive_index, diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index a87b6e37b..a464c9f24 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -122,8 +122,8 @@ class TestEAGLEEngine(unittest.TestCase): def _test_acc_length(self, engine): prompt = [ - "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:" - ] * 5 + "Human: Give me a fully functional FastAPI server. 
Show the python code.\n\nAssistant:", + ] * 5 # test batched generation sampling_params = {"temperature": 0, "max_new_tokens": 512} output = engine.generate(prompt, sampling_params) output = output[0] diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py index 26b5740b1..68d1749ab 100644 --- a/test/srt/test_mla_flashinfer.py +++ b/test/srt/test_mla_flashinfer.py @@ -67,7 +67,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase): "--enable-torch-compile", "--disable-cuda-graph", "--cuda-graph-max-bs", - "2", + "4", "--enable-flashinfer-mla", "--flashinfer-mla-disable-ragged", ] @@ -109,7 +109,7 @@ class TestFlashinferMLAMTP(unittest.TestCase): other_args.extend( [ "--cuda-graph-max-bs", - "2", + "4", "--disable-radix", "--enable-torch-compile", "--torch-compile-max-bs",