diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index d2ef57328..e7d548710 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -482,6 +482,7 @@ class BatchEmbeddingOut:
embeddings: List[List[float]]
# Token counts
prompt_tokens: List[int]
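+    # Number of prompt tokens per request that hit the prefix cache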
+ cached_tokens: List[int]
@dataclass
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 9351908c5..10698b0bc 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -159,17 +159,6 @@ class Scheduler:
)
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
- self.decode_mem_cache_buf_multiplier = (
- (
- self.server_args.speculative_num_draft_tokens
- + (
- self.server_args.speculative_eagle_topk
- * self.server_args.speculative_num_draft_tokens
- )
- )
- if not self.spec_algorithm.is_none()
- else 1
- )
# Distributed rank info
self.dp_size = server_args.dp_size
@@ -208,42 +197,12 @@ class Scheduler:
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
# Init tokenizer
- self.model_config = ModelConfig(
- server_args.model_path,
- trust_remote_code=server_args.trust_remote_code,
- revision=server_args.revision,
- context_length=server_args.context_length,
- model_override_args=server_args.json_model_override_args,
- is_embedding=server_args.is_embedding,
- dtype=server_args.dtype,
- quantization=server_args.quantization,
- )
- self.is_generation = self.model_config.is_generation
-
- if server_args.skip_tokenizer_init:
- self.tokenizer = self.processor = None
- else:
- if self.model_config.is_multimodal:
- self.processor = get_processor(
- server_args.tokenizer_path,
- tokenizer_mode=server_args.tokenizer_mode,
- trust_remote_code=server_args.trust_remote_code,
- revision=server_args.revision,
- )
- self.tokenizer = self.processor.tokenizer
- else:
- self.tokenizer = get_tokenizer(
- server_args.tokenizer_path,
- tokenizer_mode=server_args.tokenizer_mode,
- trust_remote_code=server_args.trust_remote_code,
- revision=server_args.revision,
- )
+ self.init_tokenizer()
# Check whether overlap can be enabled
if not self.is_generation:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for embedding models.")
-
if self.model_config.is_multimodal:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for multimodal models.")
@@ -307,32 +266,7 @@ class Scheduler:
)
# Init memory pool and cache
- self.req_to_token_pool, self.token_to_kv_pool_allocator = (
- self.tp_worker.get_memory_pool()
- )
-
- if (
- server_args.chunked_prefill_size is not None
- and server_args.disable_radix_cache
- ):
- self.tree_cache = ChunkCache(
- req_to_token_pool=self.req_to_token_pool,
- token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
- )
- else:
- if self.enable_hierarchical_cache:
- self.tree_cache = HiRadixCache(
- req_to_token_pool=self.req_to_token_pool,
- token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
- )
- else:
- self.tree_cache = RadixCache(
- req_to_token_pool=self.req_to_token_pool,
- token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
- disable=server_args.disable_radix_cache,
- )
-
- self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+ self.init_memory_pool_and_cache()
# Init running status
self.waiting_queue: List[Req] = []
@@ -346,25 +280,13 @@ class Scheduler:
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
- self.spec_num_total_accepted_tokens = 0
- self.spec_num_total_forward_ct = 0
- self.cum_spec_accept_length = 0
- self.cum_spec_accept_count = 0
self.last_decode_stats_tic = time.time()
self.return_health_check_ct = 0
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
- # For metrics only.
- # The largest prefill length of a single request
- self._largest_prefill_len: int = 0
- # The largest context length (prefill + generation) of a single request
- self._largest_prefill_decode_len: int = 0
- self.last_gen_throughput: float = 0.0
- self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
-
- # Session info
+ # Init session info
self.sessions: Dict[str, Session] = {}
# Init chunked prefill
@@ -385,11 +307,11 @@ class Scheduler:
else:
self.grammar_backend = None
- # Init new token estimation
+ # Init schedule policy and new token estimation
+ self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
-
self.init_new_token_ratio = min(
global_config.default_init_new_token_ratio
* server_args.schedule_conservativeness,
@@ -428,14 +350,7 @@ class Scheduler:
self.profiler_target_forward_ct: Optional[int] = None
# Init metrics stats
- self.stats = SchedulerStats()
- if self.enable_metrics:
- self.metrics_collector = SchedulerMetricsCollector(
- labels={
- "model_name": self.server_args.served_model_name,
- # TODO: Add lora name/path in the future,
- },
- )
+ self.init_metrics()
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
@@ -458,39 +373,104 @@ class Scheduler:
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(ProfileReq, self.profile),
(GetInternalStateReq, self.get_internal_state),
+ (SetInternalStateReq, self.set_internal_state),
]
)
- def watchdog_thread(self):
- """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
- self.watchdog_last_forward_ct = 0
- self.watchdog_last_time = time.time()
+ def init_tokenizer(self):
+ server_args = self.server_args
- while True:
- current = time.time()
- if self.cur_batch is not None:
- if self.watchdog_last_forward_ct == self.forward_ct:
- if current > self.watchdog_last_time + self.watchdog_timeout:
- logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
- break
- else:
- self.watchdog_last_forward_ct = self.forward_ct
- self.watchdog_last_time = current
- time.sleep(self.watchdog_timeout // 2)
-
- # Print batch size and memory pool info to check whether there are de-sync issues.
- logger.error(
- f"{self.cur_batch.batch_size()=}, "
- f"{self.cur_batch.reqs=}, "
- f"{self.token_to_kv_pool_allocator.available_size()=}, "
- f"{self.tree_cache.evictable_size()=}, "
+ self.model_config = ModelConfig(
+ server_args.model_path,
+ trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
+ context_length=server_args.context_length,
+ model_override_args=server_args.json_model_override_args,
+ is_embedding=server_args.is_embedding,
+ dtype=server_args.dtype,
+ quantization=server_args.quantization,
)
- # Wait for some time so that the parent process can print the error.
- pyspy_dump_schedulers()
- print(file=sys.stderr, flush=True)
- print(file=sys.stdout, flush=True)
- time.sleep(5)
- self.parent_process.send_signal(signal.SIGQUIT)
+ self.is_generation = self.model_config.is_generation
+
+ if server_args.skip_tokenizer_init:
+ self.tokenizer = self.processor = None
+ else:
+ if self.model_config.is_multimodal:
+ self.processor = get_processor(
+ server_args.tokenizer_path,
+ tokenizer_mode=server_args.tokenizer_mode,
+ trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
+ )
+ self.tokenizer = self.processor.tokenizer
+ else:
+ self.tokenizer = get_tokenizer(
+ server_args.tokenizer_path,
+ tokenizer_mode=server_args.tokenizer_mode,
+ trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
+ )
+
+ def init_memory_pool_and_cache(self):
+ server_args = self.server_args
+
+ self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+ self.tp_worker.get_memory_pool()
+ )
+
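+        # Pick the prefix cache: ChunkCache does no cross-request sharing, while
+        # HiRadixCache extends RadixCache with a hierarchical host-memory tier.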
+ if (
+ server_args.chunked_prefill_size is not None
+ and server_args.disable_radix_cache
+ ):
+ self.tree_cache = ChunkCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ )
+ else:
+ if self.enable_hierarchical_cache:
+ self.tree_cache = HiRadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ )
+ else:
+ self.tree_cache = RadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ disable=server_args.disable_radix_cache,
+ )
+
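+        # Extra decode-time KV slots needed with speculative decoding: the draft
+        # tokens under verification plus top-k candidates at each draft step.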
+ self.decode_mem_cache_buf_multiplier = (
+ 1
+ if self.spec_algorithm.is_none()
+ else (
+ server_args.speculative_num_draft_tokens
+ + (
+ server_args.speculative_eagle_topk
+ * server_args.speculative_num_steps
+ )
+ )
+ )
+
+ def init_metrics(self):
+ # The largest prefill length of a single request
+ self._largest_prefill_len: int = 0
+ # The largest context length (prefill + generation) of a single request
+ self._largest_prefill_decode_len: int = 0
+ self.last_gen_throughput: float = 0.0
+ self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
+ self.spec_num_total_accepted_tokens = 0
+ self.spec_num_total_forward_ct = 0
+ self.cum_spec_accept_length = 0
+ self.cum_spec_accept_count = 0
+ self.stats = SchedulerStats()
+ if self.enable_metrics:
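+            # "unified" is currently the only engine type; kept as a label for future variants.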
+ engine_type = "unified"
+ self.metrics_collector = SchedulerMetricsCollector(
+ labels={
+ "model_name": self.server_args.served_model_name,
+ "engine_type": engine_type,
+ },
+ )
@torch.no_grad()
def event_loop_normal(self):
@@ -1176,6 +1156,7 @@ class Scheduler:
):
self.stop_profile()
+ # Run forward
if self.is_generation:
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
@@ -1196,6 +1177,7 @@ class Scheduler:
self.spec_num_total_forward_ct += batch.batch_size()
self.num_generated_tokens += num_accepted_tokens
batch.output_ids = next_token_ids
+
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
# we can use the correct values in output processing.
@@ -1229,7 +1211,6 @@ class Scheduler:
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
if batch.forward_mode.is_decode():
- assert isinstance(result, GenerationBatchResult)
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
@@ -1481,6 +1462,7 @@ class Scheduler:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
+
self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool_allocator.free_group_end()
@@ -1584,7 +1566,9 @@ class Scheduler:
req.temp_input_token_ids_logprobs_idx
)
for val, idx in zip(
- req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
+ req.temp_input_top_logprobs_val,
+ req.temp_input_top_logprobs_idx,
+ strict=True,
):
req.input_top_logprobs_val.extend(val)
req.input_top_logprobs_idx.extend(idx)
@@ -1809,14 +1793,18 @@ class Scheduler:
else: # embedding or reward model
embeddings = []
prompt_tokens = []
+ cached_tokens = []
for req in reqs:
if req.finished():
rids.append(req.rid)
finished_reasons.append(req.finished_reason.to_json())
embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids))
+ cached_tokens.append(req.cached_tokens)
self.send_to_detokenizer.send_pyobj(
- BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
+ BatchEmbeddingOut(
+ rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
+ )
)
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
@@ -1902,6 +1890,37 @@ class Scheduler:
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+ def watchdog_thread(self):
+ """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
+ self.watchdog_last_forward_ct = 0
+ self.watchdog_last_time = time.time()
+
+ while True:
+ current = time.time()
+ if self.cur_batch is not None:
+ if self.watchdog_last_forward_ct == self.forward_ct:
+ if current > self.watchdog_last_time + self.watchdog_timeout:
+ logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+ break
+ else:
+ self.watchdog_last_forward_ct = self.forward_ct
+ self.watchdog_last_time = current
+ time.sleep(self.watchdog_timeout // 2)
+
+ # Print batch size and memory pool info to check whether there are de-sync issues.
+ logger.error(
+ f"{self.cur_batch.batch_size()=}, "
+ f"{self.cur_batch.reqs=}, "
+ f"{self.token_to_kv_pool_allocator.available_size()=}, "
+ f"{self.tree_cache.evictable_size()=}, "
+ )
+ # Wait for some time so that the parent process can print the error.
+ pyspy_dump_schedulers()
+ print(file=sys.stderr, flush=True)
+ print(file=sys.stdout, flush=True)
+ time.sleep(5)
+ self.parent_process.send_signal(signal.SIGQUIT)
+
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
self.flush_cache()
@@ -1913,7 +1932,6 @@ class Scheduler:
self.cur_batch = None
self.last_batch = None
self.tree_cache.reset()
- self.tree_cache_metrics = {"total": 0, "hit": 0}
if self.grammar_backend:
self.grammar_backend.reset()
self.req_to_token_pool.clear()
@@ -2005,6 +2023,9 @@ class Scheduler:
req.to_abort = True
break
+ def _pause_engine(self) -> Tuple[List[Req], int]:
+ raise NotImplementedError()
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 486f1d24c..743c0c430 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -1068,6 +1068,7 @@ class TokenizerManager:
self.metrics_collector.observe_one_finished_request(
recv_obj.prompt_tokens[i],
completion_tokens,
+ recv_obj.cached_tokens[i],
state.finished_time - state.created_time,
)
diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py
index 9f7d6d579..45fe2fce6 100644
--- a/python/sglang/srt/metrics/collector.py
+++ b/python/sglang/srt/metrics/collector.py
@@ -121,6 +121,12 @@ class TokenizerMetricsCollector:
labelnames=labels.keys(),
)
+ self.cached_tokens_total = Counter(
+ name="sglang:cached_tokens_total",
+ documentation="Number of cached prompt tokens.",
+ labelnames=labels.keys(),
+ )
+
self.num_requests_total = Counter(
name="sglang:num_requests_total",
documentation="Number of requests processed.",
@@ -245,10 +251,12 @@ class TokenizerMetricsCollector:
self,
prompt_tokens: int,
generation_tokens: int,
+ cached_tokens: int,
e2e_latency: float,
):
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+ self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
self.num_requests_total.labels(**self.labels).inc(1)
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
if generation_tokens >= 1:
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 9571faf22..3dffb2584 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -1,16 +1,20 @@
import multiprocessing as mp
+import os
import random
import threading
import time
import unittest
from types import SimpleNamespace
+from typing import List, Optional
import requests
+import torch
import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
+from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
@@ -19,7 +23,9 @@ from sglang.test.test_utils import (
popen_launch_server,
)
-acc_rate_tolerance = 0.15
+torch_dtype = torch.float16
+prefill_tolerance: float = 5e-2
+decode_tolerance: float = 5e-2
class TestEAGLEEngine(unittest.TestCase):
@@ -28,51 +34,72 @@ class TestEAGLEEngine(unittest.TestCase):
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
- "speculative_eagle_topk": 8,
- "speculative_num_draft_tokens": 64,
+ "speculative_eagle_topk": 4,
+ "speculative_num_draft_tokens": 8,
"mem_fraction_static": 0.7,
- "cuda_graph_max_bs": 32,
+ "cuda_graph_max_bs": 5,
}
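+    # Subclasses may lower this to run only a prefix of the configs in test_correctness.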
+ NUM_CONFIGS = 3
def setUp(self):
self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
- ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+ ref_engine = sgl.Engine(
+ model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1
+ )
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown()
def test_correctness(self):
configs = [
+ # Basic config
self.BASE_CONFIG,
+ # Disable cuda graph
{**self.BASE_CONFIG, "disable_cuda_graph": True},
- {**self.BASE_CONFIG, "chunked_prefill_size": 2},
+ # Chunked prefill
+ {**self.BASE_CONFIG, "chunked_prefill_size": 4},
]
- for config in configs:
- with self.subTest(
- cuda_graph=(
- "enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
- ),
- chunked_prefill_size=(
- config["chunked_prefill_size"]
- if "chunked_prefill_size" in config
- else "default"
- ),
- ):
- engine = sgl.Engine(**config)
+ for i, config in enumerate(configs[: self.NUM_CONFIGS]):
+ with self.subTest(i=i):
+ print(f"{config=}")
+ engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
try:
- self._test_basic_generation(engine)
- self._test_eos_token(engine)
+ self._test_single_generation(engine)
self._test_batch_generation(engine)
+ self._test_eos_token(engine)
+ self._test_acc_length(engine)
finally:
engine.shutdown()
+ print("=" * 100)
- def _test_basic_generation(self, engine):
+ def _test_single_generation(self, engine):
output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{output=}, {self.ref_output=}")
self.assertEqual(output, self.ref_output)
+ def _test_batch_generation(self, engine):
+ prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ ]
+ params = {"temperature": 0, "max_new_tokens": 50}
+
+ outputs = engine.generate(prompts, params)
+ for prompt, output in zip(prompts, outputs):
+ print(f"Prompt: {prompt}")
+ print(f"Generated: {output['text']}")
+ print("-" * 40)
+
+ print(f"{engine.get_server_info()=}")
+
+ avg_spec_accept_length = engine.get_server_info()["avg_spec_accept_length"]
+ print(f"{avg_spec_accept_length=}")
+ self.assertGreater(avg_spec_accept_length, 1.9)
+
def _test_eos_token(self, engine):
prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]"
params = {
@@ -88,32 +115,54 @@ class TestEAGLEEngine(unittest.TestCase):
tokens = tokenizer.encode(output, truncation=False)
self.assertNotIn(tokenizer.eos_token_id, tokens)
- def _test_batch_generation(self, engine):
- prompts = [
- "Hello, my name is",
- "The president of the United States is",
- "The capital of France is",
- "The future of AI is",
+ def _test_acc_length(self, engine):
+ prompt = [
+ "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
]
- params = {"temperature": 0, "max_new_tokens": 30}
+ sampling_params = {"temperature": 0, "max_new_tokens": 512}
+ output = engine.generate(prompt, sampling_params)
+ output = output[0]
- outputs = engine.generate(prompts, params)
- for prompt, output in zip(prompts, outputs):
- print(f"Prompt: {prompt}")
- print(f"Generated: {output['text']}")
- print("-" * 40)
+ if "spec_verify_ct" in output["meta_info"]:
+ acc_length = (
+ output["meta_info"]["completion_tokens"]
+ / output["meta_info"]["spec_verify_ct"]
+ )
+ else:
+ acc_length = 1.0
+
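+        # Decode throughput in tokens/s, from the server-reported e2e latency.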
+ speed = (
+ output["meta_info"]["completion_tokens"]
+ / output["meta_info"]["e2e_latency"]
+ )
+ print(f"{acc_length=}")
+ self.assertGreater(acc_length, 3.6)
-prompts = [
- "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]"
- '[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
- "[INST] <>\\nYou are a helpful assistant.\\n<>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
- "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwho are you?[/INST]",
- "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwhere are you from?[/INST]",
-]
+class TestEAGLEEngineTokenMap(TestEAGLEEngine):
+ BASE_CONFIG = {
+ "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
+ "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
+ "speculative_algorithm": "EAGLE",
+ "speculative_num_steps": 5,
+ "speculative_eagle_topk": 4,
+ "speculative_num_draft_tokens": 8,
+ "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
+ "mem_fraction_static": 0.7,
+ "cuda_graph_max_bs": 5,
+ }
+ NUM_CONFIGS = 1
class TestEAGLEServer(unittest.TestCase):
+ PROMPTS = [
+ "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]"
+ '[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
+ "[INST] <>\\nYou are a helpful assistant.\\n<>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
+ "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwho are you?[/INST]",
+ "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwhere are you from?[/INST]",
+ ]
+
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
@@ -127,17 +176,17 @@ class TestEAGLEServer(unittest.TestCase):
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
- "5",
+ 5,
"--speculative-eagle-topk",
- "8",
+ 8,
"--speculative-num-draft-tokens",
- "64",
+ 64,
"--mem-fraction-static",
- "0.7",
+ 0.7,
"--chunked-prefill-size",
- "128",
- "--cuda-graph-max-bs",
- "32",
+ 128,
+ "--max-running-requests",
+ 8,
],
)
@@ -147,7 +196,7 @@ class TestEAGLEServer(unittest.TestCase):
def send_request(self):
time.sleep(random.uniform(0, 2))
- for prompt in prompts:
+ for prompt in self.PROMPTS:
url = self.base_url + "/generate"
data = {
"text": prompt,
@@ -160,7 +209,7 @@ class TestEAGLEServer(unittest.TestCase):
assert response.status_code == 200
def send_requests_abort(self):
- for prompt in prompts:
+ for prompt in self.PROMPTS:
try:
time.sleep(random.uniform(0, 2))
url = self.base_url + "/generate"
@@ -192,6 +241,8 @@ class TestEAGLEServer(unittest.TestCase):
p.join()
def test_gsm8k(self):
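+        # Flush the cache so this run starts from a clean state.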
+        requests.get(self.base_url + "/flush_cache")
+
args = SimpleNamespace(
num_shots=5,
data_path=None,
@@ -201,96 +252,25 @@ class TestEAGLEServer(unittest.TestCase):
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
+
metrics = run_eval(args)
print(f"{metrics=}")
-
self.assertGreater(metrics["accuracy"], 0.20)
+ server_info = requests.get(self.base_url + "/get_server_info")
+ avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
+ print(f"{avg_spec_accept_length=}")
+ self.assertGreater(avg_spec_accept_length, 2.9)
-def measure_acc_rate(engine):
- tic = time.time()
- prompt = [
- "Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>\n\nAssistant:"
- ]
- sampling_params = {"temperature": 0, "max_new_tokens": 512}
- output = engine.generate(prompt, sampling_params)
- output = output[0]
- latency = time.time() - tic
-
- if "spec_verify_ct" in output["meta_info"]:
- base_acc_length = (
- output["meta_info"]["completion_tokens"]
- / output["meta_info"]["spec_verify_ct"]
- )
- else:
- base_acc_length = 0.0
-
- base_speed = output["meta_info"]["completion_tokens"] / latency
- return base_acc_length, base_speed
+ # Wait a little bit so that the memory check happens.
+ time.sleep(4)
-class TestEagleAcceptanceRate(unittest.TestCase):
-
- @classmethod
- def setUpClass(cls):
- mp.set_start_method("spawn", force=True)
- ref_engine = sgl.Engine(
- model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
- speculative_algorithm="EAGLE",
- speculative_num_steps=5,
- speculative_eagle_topk=8,
- speculative_num_draft_tokens=64,
- mem_fraction_static=0.7,
- disable_radix_cache=True,
- )
- cls.base_acc_length, cls.base_speed = measure_acc_rate(ref_engine)
- ref_engine.shutdown()
- assert cls.base_acc_length > 4.45
-
- def test_acc_rate(self):
- base_acc_length, base_speed = self.base_acc_length, self.base_speed
- chunk_engine = sgl.Engine(
- model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
- speculative_algorithm="EAGLE",
- speculative_num_steps=5,
- speculative_eagle_topk=8,
- speculative_num_draft_tokens=64,
- mem_fraction_static=0.7,
- chunked_prefill_size=2,
- disable_radix_cache=True,
- )
- chunked_acc_length, chunked_base_speed = measure_acc_rate(chunk_engine)
- chunk_engine.shutdown()
- print(base_acc_length, base_speed)
- print(chunked_acc_length, chunked_base_speed)
- assert abs(base_acc_length - chunked_acc_length) < acc_rate_tolerance
-
- def test_acc_rate_prefix_caching(self):
- base_acc_length, base_speed = self.base_acc_length, self.base_speed
- prefix_caching_engine = sgl.Engine(
- model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
- speculative_algorithm="EAGLE",
- speculative_num_steps=5,
- speculative_eagle_topk=8,
- speculative_num_draft_tokens=64,
- mem_fraction_static=0.7,
- chunked_prefill_size=4,
- schedule_policy="lpm",
- )
- for _ in range(10):
- acc_length, _ = measure_acc_rate(prefix_caching_engine)
- print(f"{acc_length=}")
- assert abs(base_acc_length - acc_length) < acc_rate_tolerance
- # The second one should hit the prefix cache.
- prefix_caching_engine.shutdown()
-
-
-class TestEAGLERetract(unittest.TestCase):
+class TestEAGLERetract(TestEAGLEServer):
@classmethod
def setUpClass(cls):
+        # Use a small KV cache to force request retraction; this also helps surface memory leaks.
+ os.environ["SGLANG_CI_SMALL_KV_SIZE"] = "4500"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
@@ -302,41 +282,20 @@ class TestEAGLERetract(unittest.TestCase):
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
- "5",
+ 5,
"--speculative-eagle-topk",
- "8",
+ 8,
"--speculative-num-draft-tokens",
- "64",
+ 64,
"--mem-fraction-static",
- "0.7",
+ 0.7,
"--chunked-prefill-size",
- "128",
+ 128,
"--max-running-requests",
- "64",
+ 64,
],
)
- @classmethod
- def tearDownClass(cls):
- kill_process_tree(cls.process.pid)
-
- def test_gsm8k(self):
- args = SimpleNamespace(
- num_shots=5,
- data_path=None,
- num_questions=200,
- max_new_tokens=512,
- parallel=128,
- host="http://127.0.0.1",
- port=int(self.base_url.split(":")[-1]),
- )
- metrics = run_eval(args)
- print(f"{metrics=}")
-
- self.assertGreater(metrics["accuracy"], 0.20)
- # Wait a little bit so that the memory check happens.
- time.sleep(5)
-
class TestEAGLEServerTriton(TestEAGLEServer):
@classmethod
@@ -352,73 +311,20 @@ class TestEAGLEServerTriton(TestEAGLEServer):
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
- "5",
+ 5,
"--speculative-eagle-topk",
- "4",
+ 8,
"--speculative-num-draft-tokens",
- "8",
+ 64,
"--mem-fraction-static",
- "0.7",
+ 0.7,
"--attention-backend",
"triton",
- "--cuda-graph-max-bs",
- "16",
+ "--max-running-requests",
+ 8,
],
)
-class TestEAGLEEngineTokenMap(unittest.TestCase):
- def setUp(self):
- self.prompt = "Today is a sunny day and I like"
- self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
-
- ref_engine = sgl.Engine(
- model_path="meta-llama/Meta-Llama-3-8B-Instruct", cuda_graph_max_bs=2
- )
- self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
- ref_engine.shutdown()
-
- def test_correctness(self):
- config = {
- "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
- "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
- "speculative_algorithm": "EAGLE",
- "speculative_num_steps": 5,
- "speculative_eagle_topk": 4,
- "speculative_num_draft_tokens": 8,
- "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
- "mem_fraction_static": 0.7,
- "cuda_graph_max_bs": 4,
- "dtype": "bfloat16",
- }
-
- engine = sgl.Engine(**config)
- try:
- self._test_basic_generation(engine)
- self._test_batch_generation(engine)
- finally:
- engine.shutdown()
-
- def _test_basic_generation(self, engine):
- output = engine.generate(self.prompt, self.sampling_params)["text"]
- print(f"{output=}, {self.ref_output=}")
- self.assertEqual(output, self.ref_output)
-
- def _test_batch_generation(self, engine):
- prompts = [
- "Hello, my name is",
- "The president of the United States is",
- "The capital of France is",
- "The future of AI is",
- ]
- params = {"temperature": 0, "max_new_tokens": 30}
-
- outputs = engine.generate(prompts, params)
- for prompt, output in zip(prompts, outputs):
- print(f"Prompt: {prompt}")
- print(f"Generated: {output['text']}")
- print("-" * 40)
-
-
if __name__ == "__main__":
unittest.main()
diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py
index 09b9b5a28..03dbf48c8 100644
--- a/test/srt/test_metrics.py
+++ b/test/srt/test_metrics.py
@@ -59,6 +59,7 @@ class TestEnableMetrics(unittest.TestCase):
"sglang:spec_accept_length",
"sglang:prompt_tokens_total",
"sglang:generation_tokens_total",
+ "sglang:cached_tokens_total",
"sglang:num_requests_total",
"sglang:time_to_first_token_seconds",
"sglang:time_per_output_token_seconds",
diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py
index 9f87eb24d..054866e76 100644
--- a/test/srt/test_moe_ep.py
+++ b/test/srt/test_moe_ep.py
@@ -94,7 +94,7 @@ class TestEpMoEFP8(unittest.TestCase):
)
metrics = run_eval(args)
- assert metrics["score"] >= 0.5
+ self.assertGreaterEqual(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
@@ -106,7 +106,7 @@ class TestEpMoEFP8(unittest.TestCase):
)
metrics = run_eval(args)
- assert metrics["score"] >= 0.8
+ self.assertGreaterEqual(metrics["score"], 0.8)
if __name__ == "__main__":