From bc1534ff32c27c4c1a0f485b07031dde516489a4 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 6 Mar 2025 06:13:59 -0800 Subject: [PATCH] Fix a draft model accuracy bug in eagle; support step=1; return logprob in eagle (#4134) Co-authored-by: Sehoon Kim Co-authored-by: SangBin Cho Co-authored-by: Sehoon Kim --- .github/workflows/pr-test.yml | 2 +- .../layers/attention/flashinfer_backend.py | 141 ++++++++++++------ .../srt/layers/attention/triton_backend.py | 6 +- .../srt/model_executor/cuda_graph_runner.py | 6 - .../srt/speculative/build_eagle_tree.py | 7 +- .../eagle_draft_cuda_graph_runner.py | 12 +- python/sglang/srt/speculative/eagle_utils.py | 3 +- python/sglang/srt/speculative/eagle_worker.py | 99 ++++++++---- test/srt/test_bench_serving.py | 2 +- test/srt/test_eagle_infer.py | 128 +++++++++++++++- test/srt/test_gptqmodel_dynamic.py | 4 +- 11 files changed, 304 insertions(+), 106 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 2bce60768..225c215c8 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -95,7 +95,7 @@ jobs: strategy: fail-fast: false matrix: - range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100] + range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-48, 48-100] steps: - name: Checkout code uses: actions/checkout@v3 diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index acec75241..de3bbe5cf 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -7,16 +7,14 @@ FlashInfer is faster and Triton is easier to customize. Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. 
""" -import math import os from dataclasses import dataclass from enum import Enum, auto from functools import partial -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Callable, List, Optional, Union import torch import triton -import triton.language as tl from sglang.global_config import global_config from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -37,7 +35,7 @@ if is_flashinfer_available(): BatchPrefillWithRaggedKVCacheWrapper, ) from flashinfer.cascade import merge_state - from flashinfer.decode import PosEncodingMode + from flashinfer.decode import _get_range_buf, get_seq_lens class WrapperDispatch(Enum): @@ -73,8 +71,6 @@ class FlashInferAttnBackend(AttentionBackend): ): super().__init__() - self.is_multimodal = model_runner.model_config.is_multimodal - # Parse constants self.decode_use_tensor_cores = should_use_tensor_core( kv_cache_dtype=model_runner.kv_cache_dtype, @@ -86,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend): ) self.max_context_len = model_runner.model_config.context_len self.skip_prefill = skip_prefill + self.is_multimodal = model_runner.model_config.is_multimodal assert not ( model_runner.sliding_window_size is not None @@ -115,7 +112,6 @@ class FlashInferAttnBackend(AttentionBackend): device=model_runner.device, ) self.workspace_buffer = global_workspace_buffer - max_bs = model_runner.req_to_token_pool.size if kv_indptr_buf is None: self.kv_indptr = [ @@ -163,9 +159,11 @@ class FlashInferAttnBackend(AttentionBackend): ) ) self.prefill_wrappers_verify.append( - BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + ) ) - self.decode_wrappers.append( BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, @@ -178,13 +176,14 @@ class FlashInferAttnBackend(AttentionBackend): if not skip_prefill: self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( model_runner, self - 
) + ) # for verify self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self) # Other metadata self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None self.decode_cuda_graph_metadata = {} - self.prefill_cuda_graph_metadata = {} + self.prefill_cuda_graph_metadata = {} # For verify + self.draft_extend_cuda_graph_metadata = {} # For draft extend def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode_or_idle(): @@ -300,7 +299,6 @@ class FlashInferAttnBackend(AttentionBackend): ], ) ) - seq_lens_sum = seq_lens.sum().item() self.indices_updater_decode.update( req_pool_indices, @@ -312,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend): ) self.decode_cuda_graph_metadata[bs] = decode_wrappers self.forward_metadata = DecodeMetadata(decode_wrappers) + for i in range(self.num_wrappers): + decode_wrappers[i].begin_forward = partial( + fast_decode_plan, decode_wrappers[i] + ) elif forward_mode.is_target_verify(): prefill_wrappers = [] for i in range(self.num_wrappers): @@ -437,7 +439,7 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), causal=False, sm_scale=layer.scaling, - logits_soft_cap=layer.logit_cap, + logits_soft_cap=logits_soft_cap, ) o, _ = merge_state(o1, s1, o2, s2) @@ -636,9 +638,15 @@ class FlashInferIndicesUpdaterDecode: bs = len(req_pool_indices) kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" - ) + + if wrapper.is_cuda_graph_enabled: + # Directly write to the cuda graph input buffer + kv_indices = wrapper._paged_kv_indices_buf + else: + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + ) + create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, req_pool_indices, @@ -649,9 +657,9 @@ class FlashInferIndicesUpdaterDecode: 
self.req_to_token.shape[1], ) else: - assert isinstance(spec_info, EagleDraftInput) kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices bs = kv_indptr.shape[0] - 1 + wrapper.begin_forward( kv_indptr, kv_indices, @@ -699,7 +707,7 @@ class FlashInferIndicesUpdaterPrefill: def update( self, - req_pool_indices: torch.Tnesor, + req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, @@ -713,7 +721,7 @@ class FlashInferIndicesUpdaterPrefill: def update_single_wrapper( self, - req_pool_indices: torch.Tnesor, + req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, @@ -858,7 +866,6 @@ class FlashInferIndicesUpdaterPrefill: kv_indices, self.req_to_token.shape[1], ) - qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) qo_indptr = qo_indptr[: bs + 1] custom_mask = None @@ -954,7 +961,10 @@ class FlashInferMultiStepDraftBackend: self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1] def common_template( - self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int + self, + forward_batch: ForwardBatch, + kv_indices_buffer: torch.Tensor, + call_fn: Callable, ): num_seqs = forward_batch.batch_size bs = self.topk * num_seqs @@ -1042,17 +1052,15 @@ class FlashInferMultiStepDraftBackend: forward_mode=ForwardMode.DECODE, spec_info=forward_batch.spec_info, ) - decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[ - forward_batch.batch_size - ][0] - decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper) self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn) - def init_forward_metadata_replay_cuda_graph(self, forward_batch): + def init_forward_metadata_replay_cuda_graph( + self, forward_batch: ForwardBatch, bs: int + ): def call_fn(i, forward_batch): self.attn_backends[i].init_forward_metadata_replay_cuda_graph( - forward_batch.batch_size, + bs, forward_batch.req_pool_indices, 
forward_batch.seq_lens, seq_lens_sum=-1, @@ -1113,6 +1121,11 @@ def should_use_tensor_core( return False +# Use as a fast path to override the indptr in flashinfer's plan function +# This is used to remove some host-to-device copy overhead. +global_override_indptr_cpu = None + + def fast_decode_plan( self, indptr: torch.Tensor, @@ -1142,6 +1155,9 @@ def fast_decode_plan( if logits_soft_cap is None: logits_soft_cap = 0.0 + if self.use_tensor_cores: + qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") + if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: raise ValueError( @@ -1154,7 +1170,7 @@ raise ValueError( "The size of indices should be less than or equal to the allocated buffer" ) - # Skip these copies + # Skip these copies because we directly write to them during preparation # self._paged_kv_indptr_buf.copy_(indptr) # self._paged_kv_indices_buf[: len(indices)] = indices # self._paged_kv_last_page_len_buf.copy_(last_page_len) @@ -1162,6 +1178,7 @@ self._paged_kv_indptr_buf = indptr self._paged_kv_indices_buf = indices self._paged_kv_last_page_len_buf = last_page_len + self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking) # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info if not q_data_type: @@ -1184,27 +1201,55 @@ ) self.last_page_len = torch.ones(32768, dtype=torch.int32) - empty_q_data = self.empty_q_data - empty_kv_cache = self.empty_kv_cache - stream = torch.cuda.current_stream() - self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - indptr.to("cpu"), - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - window_left, - logits_soft_cap, - head_dim, - head_dim, - empty_q_data, - empty_kv_cache, - stream.cuda_stream, + indptr_host = ( + global_override_indptr_cpu + if global_override_indptr_cpu is not None + 
else indptr.cpu() ) + + if self.use_tensor_cores: + kv_lens_arr_host = get_seq_lens( + indptr_host, self.last_page_len[:batch_size], page_size + ) + + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + indptr_host, + kv_lens_arr_host, + batch_size, # total_num_rows + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + head_dim, + head_dim, + False, # causal + torch.cuda.current_stream().cuda_stream, + ) + else: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + indptr_host, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + window_left, + logits_soft_cap, + head_dim, + head_dim, + self.empty_q_data, + self.empty_kv_cache, + torch.cuda.current_stream().cuda_stream, + ) + self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left self._logits_soft_cap = logits_soft_cap diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index d07baf102..a9d726180 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -578,10 +578,12 @@ class TritonMultiStepDraftBackend: self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn) - def init_forward_metadata_replay_cuda_graph(self, forward_batch): + def init_forward_metadata_replay_cuda_graph( + self, forward_batch: ForwardBatch, bs: int + ): def call_fn(i, forward_batch): self.attn_backends[i].init_forward_metadata_replay_cuda_graph( - forward_batch.batch_size, + bs, forward_batch.req_pool_indices, forward_batch.seq_lens, seq_lens_sum=-1, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 
842f59a3b..813fbf6fc 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -396,16 +396,10 @@ class CudaGraphRunner: run_once() - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - global global_graph_memory_pool with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream): out = run_once() - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - global_graph_memory_pool = graph.pool() return graph, out diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index 027838ab1..fba411479 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -26,7 +26,12 @@ def build_tree_kernel_efficient_preprocess( draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten() - parent_list = torch.cat(parents_list[:-1], dim=1) + + if len(parents_list) > 1: + parent_list = torch.cat(parents_list[:-1], dim=1) + else: + batch_size = parents_list[0].shape[0] + parent_list = torch.empty(batch_size, 0, device=parents_list[0].device) return parent_list, top_scores_index, draft_tokens diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index c3ecb80a4..e5410ec00 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -1,7 +1,6 @@ from __future__ import annotations import bisect -import time from typing import TYPE_CHECKING, Callable import torch @@ -162,20 +161,11 @@ class EAGLEDraftCudaGraphRunner: run_once() - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - with torch.cuda.graph( graph, 
pool=get_global_graph_memory_pool(), stream=stream ): out = run_once() - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - set_global_graph_memory_pool(graph.pool()) return graph, out @@ -204,7 +194,7 @@ class EAGLEDraftCudaGraphRunner: # Attention backend self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph( - forward_batch + forward_batch, forward_batch.batch_size ) # Replay diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 17e688085..086b532ac 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, List import torch import torch.nn.functional as F @@ -62,6 +62,7 @@ class EagleDraftInput: batch.input_ids[pt : pt + extend_len] = torch.concat( (input_ids[1:], self.verified_id[i].reshape(1)) ) + pt += extend_len def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps): assert self.verified_id.numel() == batch.out_cache_loc.shape[0] diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 4dce896c0..12da787eb 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1,20 +1,19 @@ import logging import os import time -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch from huggingface_hub import snapshot_download from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, 
ForwardMode, ) -from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( EAGLEDraftCudaGraphRunner, @@ -27,7 +26,6 @@ from sglang.srt.speculative.eagle_utils import ( fast_topk, select_top_k_tokens, ) -from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import get_available_gpu_memory logger = logging.getLogger(__name__) @@ -44,16 +42,30 @@ class EAGLEWorker(TpModelWorker): nccl_port: int, target_worker: TpModelWorker, ): + # Parse arguments + self.server_args = server_args + self.topk = server_args.speculative_eagle_topk + self.speculative_num_steps = server_args.speculative_num_steps + self.padded_static_len = self.speculative_num_steps + 1 + self.enable_nan_detection = server_args.enable_nan_detection + self.gpu_id = gpu_id + self.device = server_args.device + self.target_worker = target_worker + # Override context length with target model's context length server_args.context_length = target_worker.model_runner.model_config.context_len - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" # Do not capture cuda graph in `super().__init__()` - # We will capture it later + # It will be captured later. backup_disable_cuda_graph = server_args.disable_cuda_graph server_args.disable_cuda_graph = True + # Share the allocator with a target worker. + # Draft and target worker own their own KV cache pools. + self.req_to_token_pool, self.token_to_kv_pool_allocator = ( + target_worker.get_memory_pool() + ) - # Lossy optimization by using hot tokens + # Load hot token ids if server_args.speculative_token_map is not None: self.hot_token_id = load_token_map(server_args.speculative_token_map) server_args.json_model_override_args = ( @@ -62,13 +74,7 @@ class EAGLEWorker(TpModelWorker): else: self.hot_token_id = None - # We share the allocator with a target worker. Draft/target worker - # owns its own KV cache. 
- self.req_to_token_pool, self.token_to_kv_pool_allocator = ( - target_worker.get_memory_pool() - ) - - # Init target worker + # Init draft worker super().__init__( gpu_id=gpu_id, tp_rank=tp_rank, @@ -79,18 +85,6 @@ class EAGLEWorker(TpModelWorker): req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, ) - self.target_worker = target_worker - - # Parse arguments - self.topk = server_args.speculative_eagle_topk - self.speculative_num_steps = server_args.speculative_num_steps - self.speculative_algorithm = SpeculativeAlgorithm.from_string( - server_args.speculative_algorithm - ) - self.server_args = server_args - self.use_nan_detection = self.server_args.enable_nan_detection - self.device = self.model_runner.device - self.gpu_id = self.model_runner.gpu_id # Share the embedding and lm_head embed, head = self.target_worker.model_runner.model.get_embed_and_head() @@ -103,8 +97,12 @@ class EAGLEWorker(TpModelWorker): backup_disable_cuda_graph ) + self.init_attention_backend() + self.init_cuda_graphs() + + def init_attention_backend(self): # Create multi-step attn backends and cuda graph runners - if server_args.attention_backend == "flashinfer": + if self.server_args.attention_backend == "flashinfer": from sglang.srt.layers.attention.flashinfer_backend import ( FlashInferMultiStepDraftBackend, ) @@ -114,7 +112,7 @@ class EAGLEWorker(TpModelWorker): self.topk, self.speculative_num_steps, ) - elif server_args.attention_backend == "triton": + elif self.server_args.attention_backend == "triton": from sglang.srt.layers.attention.triton_backend import ( TritonMultiStepDraftBackend, ) @@ -126,11 +124,9 @@ class EAGLEWorker(TpModelWorker): ) else: raise ValueError( - f"EAGLE is not supportted in attention backend {server_args.attention_backend}" + f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}" ) - self.draft_model_runner.draft_attn_backend = self.draft_attn_backend - self.init_cuda_graphs() def 
init_cuda_graphs(self): """Capture cuda graphs.""" @@ -356,6 +352,41 @@ class EAGLEWorker(TpModelWorker): batch.forward_mode = ForwardMode.DECODE batch.spec_info = res.draft_input + if batch.return_logprob: + # Compute output logprobs using the sampler. + num_tokens_per_req = [ + accept + 1 for accept in res.accept_length_per_req_cpu + ] + self.target_worker.model_runner.update_output_logprobs( + logits_output, + batch.sampling_info, + batch.top_logprobs_nums, + batch.token_ids_logprobs, + res.verified_id, + # +1 for bonus token. + num_tokens_per_req=num_tokens_per_req, + ) + + # Add output logprobs to the request. + pt = 0 + # NOTE: tolist() of these values are skipped when output is processed + next_token_logprobs = res.logits_output.next_token_logprobs.tolist() + verified_ids = res.verified_id.tolist() + for req, num_tokens in zip(batch.reqs, num_tokens_per_req): + for _ in range(num_tokens): + if req.return_logprob: + token_id = verified_ids[pt] + req.output_token_logprobs_val.append(next_token_logprobs[pt]) + req.output_token_logprobs_idx.append(token_id) + if req.top_logprobs_num > 0: + req.output_top_logprobs_val.append( + res.logits_output.next_token_top_logprobs_val[pt] + ) + req.output_top_logprobs_idx.append( + res.logits_output.next_token_top_logprobs_idx[pt] + ) + pt += 1 + return logits_output, res, model_worker_batch def forward_draft_extend( @@ -381,6 +412,7 @@ class EAGLEWorker(TpModelWorker): forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) + forward_batch.return_logprob = False logits_output = self.draft_model_runner.forward(forward_batch) self._detect_nan_if_needed(logits_output) assert isinstance(forward_batch.spec_info, EagleDraftInput) @@ -393,6 +425,8 @@ class EAGLEWorker(TpModelWorker): batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps) batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST # We don't need logprob for this extend. 
+ original_return_logprob = batch.return_logprob + batch.return_logprob = False model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner @@ -404,6 +438,7 @@ class EAGLEWorker(TpModelWorker): # Restore backup. # This is because `seq_lens` can be modified in `prepare_extend_after_decode` + batch.return_logprob = original_return_logprob batch.forward_mode = ForwardMode.DECODE batch.seq_lens = seq_lens_backup @@ -415,7 +450,7 @@ class EAGLEWorker(TpModelWorker): draft_input.hidden_states = logits_output.hidden_states def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput): - if self.use_nan_detection: + if self.enable_nan_detection: logits = logits_output.next_token_logits if torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! NaN in the logits.") diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index 939c2b5cb..f8b4b1f9a 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -165,7 +165,7 @@ class TestBenchServing(unittest.TestCase): f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' f'accept_length : {res["accept_length"]:.2f} \n' ) - self.assertLess(res["median_e2e_latency_ms"], 1100) + self.assertLess(res["median_e2e_latency_ms"], 900) self.assertGreater(res["accept_length"], 2.99) def test_moe_offline_throughput_default(self): diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 3dffb2584..cadca667b 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -1,12 +1,16 @@ +import json import multiprocessing as mp import os import random import threading import time import unittest +from concurrent.futures import ThreadPoolExecutor +from functools import partial from types import SimpleNamespace from typing import List, Optional +import numpy as np import requests import torch @@ -21,6 +25,7 @@ from sglang.test.test_utils import 
( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, + run_logprob_check, ) torch_dtype = torch.float16 @@ -260,11 +265,132 @@ class TestEAGLEServer(unittest.TestCase): server_info = requests.get(self.base_url + "/get_server_info") avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 2.9) + self.assertGreater(avg_spec_accept_length, 3.5) # Wait a little bit so that the memory check happens. time.sleep(4) + def test_logprob_start_len(self): + logprob_start_len = 4 + new_tokens = 4 + prompts = [ + "I have a very good idea on", + "Today is a sunndy day and", + ] + + response = requests.post( + self.base_url + "/generate", + json={ + "text": prompts, + "sampling_params": { + "temperature": 0, + "max_new_tokens": new_tokens, + }, + "return_logprob": True, + "top_logprobs_num": 5, + "logprob_start_len": logprob_start_len, + }, + ) + response_json = response.json() + print(json.dumps(response_json, indent=2)) + + for res in response_json: + self.assertEqual( + res["meta_info"]["prompt_tokens"], + logprob_start_len + len(res["meta_info"]["input_token_logprobs"]), + ) + + self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens) + self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens) + + def test_logprob_match(self): + """Test the output logprobs are close to the input logprobs if we run a prefill again.""" + + def run_generate( + prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1 + ): + + if isinstance(prompt, str): + prompt_kwargs = {"text": prompt} + else: + prompt_kwargs = {"input_ids": prompt} + + response = requests.post( + self.base_url + "/generate", + json={ + **prompt_kwargs, + "sampling_params": { + "temperature": 1.0, + "max_new_tokens": max_new_tokens, + "ignore_eos": True, + }, + "return_logprob": return_logprob, + "return_text_in_logprobs": True, + "logprob_start_len": 
logprob_start_len, + }, + ) + return response.json() + + prompt = "I have a very good idea on how to" + + gen = run_generate(prompt, return_logprob=True, logprob_start_len=0) + output_logprobs = np.array( + [x[0] for x in gen["meta_info"]["output_token_logprobs"]] + ) + num_prompts_tokens = gen["meta_info"]["prompt_tokens"] + + input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]] + output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]] + + new_prompt = input_tokens + output_tokens + score = run_generate( + new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0 + ) + output_logprobs_score = np.array( + [ + x[0] + for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:] + ] + ) + + print(f"{output_logprobs[-10:]=}") + print(f"{output_logprobs_score[-10:]=}") + + diff = np.abs(output_logprobs - output_logprobs_score) + max_diff = np.max(diff) + self.assertLess(max_diff, 0.25) + + def test_logprob_mixed(self): + args = [] + temperature = 0 + # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num + # Llama 2 context length seems to be only 2k, so we can only test small length. 
+ for input_len in [200, 500, 1000, 2000]: + for output_len in [4, 8]: + for logprob_start_len in [0, 100, 300, 800, 1998]: + for return_logprob in [True, False]: + for top_logprobs_num in [0, 5]: + + if logprob_start_len >= input_len: + continue + + args.append( + ( + input_len, + output_len, + temperature, + logprob_start_len, + return_logprob, + top_logprobs_num, + ) + ) + + random.shuffle(args) + + func = partial(run_logprob_check, self) + with ThreadPoolExecutor(8) as executor: + list(executor.map(func, args)) + class TestEAGLERetract(TestEAGLEServer): @classmethod diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index f22f37f1d..92e17a8e4 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ -143,11 +143,11 @@ class TestGPTQModelDynamic(unittest.TestCase): print(f"result = `{result}`") - assert "paris" in result["text"].lower() + self.assertIn("paris", result["text"].lower()) throughput = max_tokens / (tok - tic) print(f"Throughput: {throughput} tokens/s") - assert throughput >= 140 + self.assertGreaterEqual(throughput, 140) def test_gptq_module(self): check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)