From fcc2e37f695b2b5e50476f5e063d83f9f3f7878e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 6 Mar 2025 00:13:20 -0800 Subject: [PATCH] Split the __init__ of scheduler as smaller functions. Improve the eagle tests (#4128) --- python/sglang/srt/managers/io_struct.py | 1 + python/sglang/srt/managers/scheduler.py | 267 +++++++------- .../sglang/srt/managers/tokenizer_manager.py | 1 + python/sglang/srt/metrics/collector.py | 8 + test/srt/test_eagle_infer.py | 338 +++++++----------- test/srt/test_metrics.py | 1 + test/srt/test_moe_ep.py | 4 +- 7 files changed, 279 insertions(+), 341 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index d2ef57328..e7d548710 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -482,6 +482,7 @@ class BatchEmbeddingOut: embeddings: List[List[float]] # Token counts prompt_tokens: List[int] + cached_tokens: List[int] @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9351908c5..10698b0bc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -159,17 +159,6 @@ class Scheduler: ) self.gpu_id = gpu_id self.enable_hierarchical_cache = server_args.enable_hierarchical_cache - self.decode_mem_cache_buf_multiplier = ( - ( - self.server_args.speculative_num_draft_tokens - + ( - self.server_args.speculative_eagle_topk - * self.server_args.speculative_num_draft_tokens - ) - ) - if not self.spec_algorithm.is_none() - else 1 - ) # Distributed rank info self.dp_size = server_args.dp_size @@ -208,42 +197,12 @@ class Scheduler: self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) # Init tokenizer - self.model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - is_embedding=server_args.is_embedding, - dtype=server_args.dtype, - quantization=server_args.quantization, - ) - self.is_generation = self.model_config.is_generation - - if server_args.skip_tokenizer_init: - self.tokenizer = self.processor = None - else: - if self.model_config.is_multimodal: - self.processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - ) - self.tokenizer = self.processor.tokenizer - else: - self.tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - ) + self.init_tokenizer() # Check whether overlap can be enabled if not self.is_generation: self.enable_overlap = False logger.info("Overlap scheduler is disabled for embedding models.") - if self.model_config.is_multimodal: self.enable_overlap = False logger.info("Overlap scheduler is disabled for multimodal models.") @@ -307,32 +266,7 @@ class Scheduler: ) # Init memory pool and cache - self.req_to_token_pool, self.token_to_kv_pool_allocator = ( - self.tp_worker.get_memory_pool() - ) - - if ( - server_args.chunked_prefill_size is not None - and server_args.disable_radix_cache - ): - self.tree_cache = ChunkCache( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - ) - else: - if self.enable_hierarchical_cache: - self.tree_cache = HiRadixCache( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - ) - else: - self.tree_cache = RadixCache( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - disable=server_args.disable_radix_cache, - ) - - self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache) + self.init_memory_pool_and_cache() # Init running status self.waiting_queue: List[Req] = [] @@ -346,25 +280,13 @@ class Scheduler: self.forward_ct = 0 self.forward_ct_decode = 0 self.num_generated_tokens = 0 - self.spec_num_total_accepted_tokens = 0 - self.spec_num_total_forward_ct = 0 - self.cum_spec_accept_length = 0 - self.cum_spec_accept_count = 0 self.last_decode_stats_tic = time.time() self.return_health_check_ct = 0 self.current_stream = torch.get_device_module(self.device).current_stream() if self.device == "cpu": self.current_stream.synchronize = lambda: None # No-op for CPU - # For metrics only. - # The largest prefill length of a single request - self._largest_prefill_len: int = 0 - # The largest context length (prefill + generation) of a single request - self._largest_prefill_decode_len: int = 0 - self.last_gen_throughput: float = 0.0 - self.step_time_dict = defaultdict(list) # Dict[batch size -> step time] - - # Session info + # Init session info self.sessions: Dict[str, Session] = {} # Init chunked prefill @@ -385,11 +307,11 @@ class Scheduler: else: self.grammar_backend = None - # Init new token estimation + # Init schedule policy and new token estimation + self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache) assert ( server_args.schedule_conservativeness >= 0 ), "Invalid schedule_conservativeness" - self.init_new_token_ratio = min( global_config.default_init_new_token_ratio * server_args.schedule_conservativeness, @@ -428,14 +350,7 @@ class Scheduler: self.profiler_target_forward_ct: Optional[int] = None # Init metrics stats - self.stats = SchedulerStats() - if self.enable_metrics: - self.metrics_collector = SchedulerMetricsCollector( - labels={ - "model_name": self.server_args.served_model_name, - # TODO: Add lora name/path in the future, - }, - ) + self.init_metrics() # Init request dispatcher self._request_dispatcher = TypeBasedDispatcher( @@ -458,39 +373,104 @@ class Scheduler: (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), (ProfileReq, self.profile), (GetInternalStateReq, self.get_internal_state), + (SetInternalStateReq, self.set_internal_state), ] ) - def watchdog_thread(self): - """A watch dog thread that will try to kill the server itself if one forward batch takes too long.""" - self.watchdog_last_forward_ct = 0 - self.watchdog_last_time = time.time() + def init_tokenizer(self): + server_args = self.server_args - while True: - current = time.time() - if self.cur_batch is not None: - if self.watchdog_last_forward_ct == self.forward_ct: - if current > self.watchdog_last_time + self.watchdog_timeout: - logger.error(f"Watchdog timeout ({self.watchdog_timeout=})") - break - else: - self.watchdog_last_forward_ct = self.forward_ct - self.watchdog_last_time = current - time.sleep(self.watchdog_timeout // 2) - - # Print batch size and memory pool info to check whether there are de-sync issues. - logger.error( - f"{self.cur_batch.batch_size()=}, " - f"{self.cur_batch.reqs=}, " - f"{self.token_to_kv_pool_allocator.available_size()=}, " - f"{self.tree_cache.evictable_size()=}, " + self.model_config = ModelConfig( + server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, + dtype=server_args.dtype, + quantization=server_args.quantization, ) - # Wait for some time so that the parent process can print the error. - pyspy_dump_schedulers() - print(file=sys.stderr, flush=True) - print(file=sys.stdout, flush=True) - time.sleep(5) - self.parent_process.send_signal(signal.SIGQUIT) + self.is_generation = self.model_config.is_generation + + if server_args.skip_tokenizer_init: + self.tokenizer = self.processor = None + else: + if self.model_config.is_multimodal: + self.processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + ) + self.tokenizer = self.processor.tokenizer + else: + self.tokenizer = get_tokenizer( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + ) + + def init_memory_pool_and_cache(self): + server_args = self.server_args + + self.req_to_token_pool, self.token_to_kv_pool_allocator = ( + self.tp_worker.get_memory_pool() + ) + + if ( + server_args.chunked_prefill_size is not None + and server_args.disable_radix_cache + ): + self.tree_cache = ChunkCache( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + ) + else: + if self.enable_hierarchical_cache: + self.tree_cache = HiRadixCache( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + ) + else: + self.tree_cache = RadixCache( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + disable=server_args.disable_radix_cache, + ) + + self.decode_mem_cache_buf_multiplier = ( + 1 + if self.spec_algorithm.is_none() + else ( + server_args.speculative_num_draft_tokens + + ( + server_args.speculative_eagle_topk + * server_args.speculative_num_steps + ) + ) + ) + + def init_metrics(self): + # The largest prefill length of a single request + self._largest_prefill_len: int = 0 + # The largest context length (prefill + generation) of a single request + self._largest_prefill_decode_len: int = 0 + self.last_gen_throughput: float = 0.0 + self.step_time_dict = defaultdict(list) # Dict[batch size -> step time] + self.spec_num_total_accepted_tokens = 0 + self.spec_num_total_forward_ct = 0 + self.cum_spec_accept_length = 0 + self.cum_spec_accept_count = 0 + self.stats = SchedulerStats() + if self.enable_metrics: + engine_type = "unified" + self.metrics_collector = SchedulerMetricsCollector( + labels={ + "model_name": self.server_args.served_model_name, + "engine_type": engine_type, + }, + ) @torch.no_grad() def event_loop_normal(self): @@ -1176,6 +1156,7 @@ class Scheduler: ): self.stop_profile() + # Run forward if self.is_generation: if self.spec_algorithm.is_none(): model_worker_batch = batch.get_model_worker_batch() @@ -1196,6 +1177,7 @@ class Scheduler: self.spec_num_total_forward_ct += batch.batch_size() self.num_generated_tokens += num_accepted_tokens batch.output_ids = next_token_ids + # These 2 values are needed for processing the output, but the values can be # modified by overlap schedule. So we have to copy them here so that # we can use the correct values in output processing. @@ -1229,7 +1211,6 @@ class Scheduler: result: Union[GenerationBatchResult, EmbeddingBatchResult], ): if batch.forward_mode.is_decode(): - assert isinstance(result, GenerationBatchResult) self.process_batch_result_decode(batch, result) if batch.is_empty(): self.running_batch = None @@ -1481,6 +1462,7 @@ class Scheduler: batch.next_batch_sampling_info.update_regex_vocab_mask() self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() + self.stream_output(batch.reqs, batch.return_logprob) self.token_to_kv_pool_allocator.free_group_end() @@ -1584,7 +1566,9 @@ class Scheduler: req.temp_input_token_ids_logprobs_idx ) for val, idx in zip( - req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx + req.temp_input_top_logprobs_val, + req.temp_input_top_logprobs_idx, + strict=True, ): req.input_top_logprobs_val.extend(val) req.input_top_logprobs_idx.extend(idx) @@ -1809,14 +1793,18 @@ class Scheduler: else: # embedding or reward model embeddings = [] prompt_tokens = [] + cached_tokens = [] for req in reqs: if req.finished(): rids.append(req.rid) finished_reasons.append(req.finished_reason.to_json()) embeddings.append(req.embedding) prompt_tokens.append(len(req.origin_input_ids)) + cached_tokens.append(req.cached_tokens) self.send_to_detokenizer.send_pyobj( - BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens) + BatchEmbeddingOut( + rids, finished_reasons, embeddings, prompt_tokens, cached_tokens + ) ) def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): @@ -1902,6 +1890,37 @@ class Scheduler: self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs]) self.grammar_queue = self.grammar_queue[num_ready_reqs:] + def watchdog_thread(self): + """A watch dog thread that will try to kill the server itself if one forward batch takes too long.""" + self.watchdog_last_forward_ct = 0 + self.watchdog_last_time = time.time() + + while True: + current = time.time() + if self.cur_batch is not None: + if self.watchdog_last_forward_ct == self.forward_ct: + if current > self.watchdog_last_time + self.watchdog_timeout: + logger.error(f"Watchdog timeout ({self.watchdog_timeout=})") + break + else: + self.watchdog_last_forward_ct = self.forward_ct + self.watchdog_last_time = current + time.sleep(self.watchdog_timeout // 2) + + # Print batch size and memory pool info to check whether there are de-sync issues. + logger.error( + f"{self.cur_batch.batch_size()=}, " + f"{self.cur_batch.reqs=}, " + f"{self.token_to_kv_pool_allocator.available_size()=}, " + f"{self.tree_cache.evictable_size()=}, " + ) + # Wait for some time so that the parent process can print the error. + pyspy_dump_schedulers() + print(file=sys.stderr, flush=True) + print(file=sys.stdout, flush=True) + time.sleep(5) + self.parent_process.send_signal(signal.SIGQUIT) + def flush_cache_wrapped(self, recv_req: FlushCacheReq): self.flush_cache() @@ -1913,7 +1932,6 @@ class Scheduler: self.cur_batch = None self.last_batch = None self.tree_cache.reset() - self.tree_cache_metrics = {"total": 0, "hit": 0} if self.grammar_backend: self.grammar_backend.reset() self.req_to_token_pool.clear() @@ -2005,6 +2023,9 @@ class Scheduler: req.to_abort = True break + def _pause_engine(self) -> Tuple[List[Req], int]: + raise NotImplementedError() + def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): """In-place update of the weights from disk.""" success, message = self.tp_worker.update_weights_from_disk(recv_req) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 486f1d24c..743c0c430 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1068,6 +1068,7 @@ class TokenizerManager: self.metrics_collector.observe_one_finished_request( recv_obj.prompt_tokens[i], completion_tokens, + recv_obj.cached_tokens[i], state.finished_time - state.created_time, ) diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index 9f7d6d579..45fe2fce6 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -121,6 +121,12 @@ class TokenizerMetricsCollector: labelnames=labels.keys(), ) + self.cached_tokens_total = Counter( + name="sglang:cached_tokens_total", + documentation="Number of cached prompt tokens.", + labelnames=labels.keys(), + ) + self.num_requests_total = Counter( name="sglang:num_requests_total", documentation="Number of requests processed.", @@ -245,10 +251,12 @@ class TokenizerMetricsCollector: self, prompt_tokens: int, generation_tokens: int, + cached_tokens: int, e2e_latency: float, ): self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens) self.generation_tokens_total.labels(**self.labels).inc(generation_tokens) + self.cached_tokens_total.labels(**self.labels).inc(cached_tokens) self.num_requests_total.labels(**self.labels).inc(1) self._log_histogram(self.histogram_e2e_request_latency, e2e_latency) if generation_tokens >= 1: diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 9571faf22..3dffb2584 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -1,16 +1,20 @@ import multiprocessing as mp +import os import random import threading import time import unittest from types import SimpleNamespace +from typing import List, Optional import requests +import torch import sglang as sgl from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner from sglang.test.test_utils import ( DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, @@ -19,7 +23,9 @@ from sglang.test.test_utils import ( popen_launch_server, ) -acc_rate_tolerance = 0.15 +torch_dtype = torch.float16 +prefill_tolerance = 5e-2 +decode_tolerance: float = 5e-2 class TestEAGLEEngine(unittest.TestCase): @@ -28,51 +34,72 @@ class TestEAGLEEngine(unittest.TestCase): "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, "speculative_algorithm": "EAGLE", "speculative_num_steps": 5, - "speculative_eagle_topk": 8, - "speculative_num_draft_tokens": 64, + "speculative_eagle_topk": 4, + "speculative_num_draft_tokens": 8, "mem_fraction_static": 0.7, - "cuda_graph_max_bs": 32, + "cuda_graph_max_bs": 5, } + NUM_CONFIGS = 3 def setUp(self): self.prompt = "Today is a sunny day and I like" self.sampling_params = {"temperature": 0, "max_new_tokens": 8} - ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) + ref_engine = sgl.Engine( + model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1 + ) self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"] ref_engine.shutdown() def test_correctness(self): configs = [ + # Basic config self.BASE_CONFIG, + # Disable cuda graph {**self.BASE_CONFIG, "disable_cuda_graph": True}, - {**self.BASE_CONFIG, "chunked_prefill_size": 2}, + # Chunked prefill + {**self.BASE_CONFIG, "chunked_prefill_size": 4}, ] - for config in configs: - with self.subTest( - cuda_graph=( - "enabled" if len(config) == len(self.BASE_CONFIG) else "disabled" - ), - chunked_prefill_size=( - config["chunked_prefill_size"] - if "chunked_prefill_size" in config - else "default" - ), - ): - engine = sgl.Engine(**config) + for i, config in enumerate(configs[: self.NUM_CONFIGS]): + with self.subTest(i=i): + print(f"{config=}") + engine = sgl.Engine(**config, log_level="info", decode_log_interval=10) try: - self._test_basic_generation(engine) - self._test_eos_token(engine) + self._test_single_generation(engine) self._test_batch_generation(engine) + self._test_eos_token(engine) + self._test_acc_length(engine) finally: engine.shutdown() + print("=" * 100) - def _test_basic_generation(self, engine): + def _test_single_generation(self, engine): output = engine.generate(self.prompt, self.sampling_params)["text"] print(f"{output=}, {self.ref_output=}") self.assertEqual(output, self.ref_output) + def _test_batch_generation(self, engine): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + params = {"temperature": 0, "max_new_tokens": 50} + + outputs = engine.generate(prompts, params) + for prompt, output in zip(prompts, outputs): + print(f"Prompt: {prompt}") + print(f"Generated: {output['text']}") + print("-" * 40) + + print(f"{engine.get_server_info()=}") + + avg_spec_accept_length = engine.get_server_info()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 1.9) + def _test_eos_token(self, engine): prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]" params = { @@ -88,32 +115,54 @@ class TestEAGLEEngine(unittest.TestCase): tokens = tokenizer.encode(output, truncation=False) self.assertNotIn(tokenizer.eos_token_id, tokens) - def _test_batch_generation(self, engine): - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + def _test_acc_length(self, engine): + prompt = [ + "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:" ] - params = {"temperature": 0, "max_new_tokens": 30} + sampling_params = {"temperature": 0, "max_new_tokens": 512} + output = engine.generate(prompt, sampling_params) + output = output[0] - outputs = engine.generate(prompts, params) - for prompt, output in zip(prompts, outputs): - print(f"Prompt: {prompt}") - print(f"Generated: {output['text']}") - print("-" * 40) + if "spec_verify_ct" in output["meta_info"]: + acc_length = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["spec_verify_ct"] + ) + else: + acc_length = 1.0 + + speed = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["e2e_latency"] + ) + print(f"{acc_length=}") + self.assertGreater(acc_length, 3.6) -prompts = [ - "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]" - '[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]', - "[INST] <>\\nYou are a helpful assistant.\\n<>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]", - "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwho are you?[/INST]", - "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwhere are you from?[/INST]", -] +class TestEAGLEEngineTokenMap(unittest.TestCase): + BASE_CONFIG = { + "model_path": "meta-llama/Meta-Llama-3-8B-Instruct", + "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B", + "speculative_algorithm": "EAGLE", + "speculative_num_steps": 5, + "speculative_eagle_topk": 4, + "speculative_num_draft_tokens": 8, + "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt", + "mem_fraction_static": 0.7, + "cuda_graph_max_bs": 5, + } + NUM_CONFIGS = 1 class TestEAGLEServer(unittest.TestCase): + PROMPTS = [ + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]" + '[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]', + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]", + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwho are you?[/INST]", + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwhere are you from?[/INST]", + ] + @classmethod def setUpClass(cls): cls.base_url = DEFAULT_URL_FOR_TEST @@ -127,17 +176,17 @@ class TestEAGLEServer(unittest.TestCase): "--speculative-draft-model-path", DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, "--speculative-num-steps", - "5", + 5, "--speculative-eagle-topk", - "8", + 8, "--speculative-num-draft-tokens", - "64", + 64, "--mem-fraction-static", - "0.7", + 0.7, "--chunked-prefill-size", - "128", - "--cuda-graph-max-bs", - "32", + 128, + "--max-running-requests", + 8, ], ) @@ -147,7 +196,7 @@ class TestEAGLEServer(unittest.TestCase): def send_request(self): time.sleep(random.uniform(0, 2)) - for prompt in prompts: + for prompt in self.PROMPTS: url = self.base_url + "/generate" data = { "text": prompt, @@ -160,7 +209,7 @@ class TestEAGLEServer(unittest.TestCase): assert response.status_code == 200 def send_requests_abort(self): - for prompt in prompts: + for prompt in self.PROMPTS: try: time.sleep(random.uniform(0, 2)) url = self.base_url + "/generate" @@ -192,6 +241,8 @@ class TestEAGLEServer(unittest.TestCase): p.join() def test_gsm8k(self): + server_info = requests.get(self.base_url + "/flush_cache") + args = SimpleNamespace( num_shots=5, data_path=None, @@ -201,96 +252,25 @@ class TestEAGLEServer(unittest.TestCase): host="http://127.0.0.1", port=int(self.base_url.split(":")[-1]), ) + metrics = run_eval(args) print(f"{metrics=}") - self.assertGreater(metrics["accuracy"], 0.20) + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 2.9) -def measure_acc_rate(engine): - tic = time.time() - prompt = [ - "Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>\n\nAssistant:" - ] - sampling_params = {"temperature": 0, "max_new_tokens": 512} - output = engine.generate(prompt, sampling_params) - output = output[0] - latency = time.time() - tic - - if "spec_verify_ct" in output["meta_info"]: - base_acc_length = ( - output["meta_info"]["completion_tokens"] - / output["meta_info"]["spec_verify_ct"] - ) - else: - base_acc_length = 0.0 - - base_speed = output["meta_info"]["completion_tokens"] / latency - return base_acc_length, base_speed + # Wait a little bit so that the memory check happens. + time.sleep(4) -class TestEagleAcceptanceRate(unittest.TestCase): - - @classmethod - def setUpClass(cls): - mp.set_start_method("spawn", force=True) - ref_engine = sgl.Engine( - model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, - speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, - speculative_algorithm="EAGLE", - speculative_num_steps=5, - speculative_eagle_topk=8, - speculative_num_draft_tokens=64, - mem_fraction_static=0.7, - disable_radix_cache=True, - ) - cls.base_acc_length, cls.base_speed = measure_acc_rate(ref_engine) - ref_engine.shutdown() - assert cls.base_acc_length > 4.45 - - def test_acc_rate(self): - base_acc_length, base_speed = self.base_acc_length, self.base_speed - chunk_engine = sgl.Engine( - model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, - speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, - speculative_algorithm="EAGLE", - speculative_num_steps=5, - speculative_eagle_topk=8, - speculative_num_draft_tokens=64, - mem_fraction_static=0.7, - chunked_prefill_size=2, - disable_radix_cache=True, - ) - chunked_acc_length, chunked_base_speed = measure_acc_rate(chunk_engine) - chunk_engine.shutdown() - print(base_acc_length, base_speed) - print(chunked_acc_length, chunked_base_speed) - assert abs(base_acc_length - chunked_acc_length) < acc_rate_tolerance - - def test_acc_rate_prefix_caching(self): - base_acc_length, base_speed = self.base_acc_length, self.base_speed - prefix_caching_engine = sgl.Engine( - model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, - speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, - speculative_algorithm="EAGLE", - speculative_num_steps=5, - speculative_eagle_topk=8, - speculative_num_draft_tokens=64, - mem_fraction_static=0.7, - chunked_prefill_size=4, - schedule_policy="lpm", - ) - for _ in range(10): - acc_length, _ = measure_acc_rate(prefix_caching_engine) - print(f"{acc_length=}") - assert abs(base_acc_length - acc_length) < acc_rate_tolerance - # The second one should hit the prefix cache. - prefix_caching_engine.shutdown() - - -class TestEAGLERetract(unittest.TestCase): +class TestEAGLERetract(TestEAGLEServer): @classmethod def setUpClass(cls): + # These config helps find a leak. + os.environ["SGLANG_CI_SMALL_KV_SIZE"] = "4500" cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, @@ -302,41 +282,20 @@ class TestEAGLERetract(unittest.TestCase): "--speculative-draft-model-path", DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, "--speculative-num-steps", - "5", + 5, "--speculative-eagle-topk", - "8", + 8, "--speculative-num-draft-tokens", - "64", + 64, "--mem-fraction-static", - "0.7", + 0.7, "--chunked-prefill-size", - "128", + 128, "--max-running-requests", - "64", + 64, ], ) - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval(args) - print(f"{metrics=}") - - self.assertGreater(metrics["accuracy"], 0.20) - # Wait a little bit so that the memory check happens. - time.sleep(5) - class TestEAGLEServerTriton(TestEAGLEServer): @classmethod @@ -352,73 +311,20 @@ class TestEAGLEServerTriton(TestEAGLEServer): "--speculative-draft-model-path", DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, "--speculative-num-steps", - "5", + 5, "--speculative-eagle-topk", - "4", + 8, "--speculative-num-draft-tokens", - "8", + 64, "--mem-fraction-static", - "0.7", + 0.7, "--attention-backend", "triton", - "--cuda-graph-max-bs", - "16", + "--max-running-requests", + 8, ], ) -class TestEAGLEEngineTokenMap(unittest.TestCase): - def setUp(self): - self.prompt = "Today is a sunny day and I like" - self.sampling_params = {"temperature": 0, "max_new_tokens": 8} - - ref_engine = sgl.Engine( - model_path="meta-llama/Meta-Llama-3-8B-Instruct", cuda_graph_max_bs=2 - ) - self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"] - ref_engine.shutdown() - - def test_correctness(self): - config = { - "model_path": "meta-llama/Meta-Llama-3-8B-Instruct", - "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B", - "speculative_algorithm": "EAGLE", - "speculative_num_steps": 5, - "speculative_eagle_topk": 4, - "speculative_num_draft_tokens": 8, - "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt", - "mem_fraction_static": 0.7, - "cuda_graph_max_bs": 4, - "dtype": "bfloat16", - } - - engine = sgl.Engine(**config) - try: - self._test_basic_generation(engine) - self._test_batch_generation(engine) - finally: - engine.shutdown() - - def _test_basic_generation(self, engine): - output = engine.generate(self.prompt, self.sampling_params)["text"] - print(f"{output=}, {self.ref_output=}") - self.assertEqual(output, self.ref_output) - - def _test_batch_generation(self, engine): - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - params = {"temperature": 0, "max_new_tokens": 30} - - outputs = engine.generate(prompts, params) - for prompt, output in zip(prompts, outputs): - print(f"Prompt: {prompt}") - print(f"Generated: {output['text']}") - print("-" * 40) - - if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py index 09b9b5a28..03dbf48c8 100644 --- a/test/srt/test_metrics.py +++ b/test/srt/test_metrics.py @@ -59,6 +59,7 @@ class TestEnableMetrics(unittest.TestCase): "sglang:spec_accept_length", "sglang:prompt_tokens_total", "sglang:generation_tokens_total", + "sglang:cached_tokens_total", "sglang:num_requests_total", "sglang:time_to_first_token_seconds", "sglang:time_per_output_token_seconds", diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py index 9f87eb24d..054866e76 100644 --- a/test/srt/test_moe_ep.py +++ b/test/srt/test_moe_ep.py @@ -94,7 +94,7 @@ class TestEpMoEFP8(unittest.TestCase): ) metrics = run_eval(args) - assert metrics["score"] >= 0.5 + self.assertGreaterEqual(metrics["score"], 0.5) def test_mgsm_en(self): args = SimpleNamespace( @@ -106,7 +106,7 @@ class TestEpMoEFP8(unittest.TestCase): ) metrics = run_eval(args) - assert metrics["score"] >= 0.8 + self.assertGreaterEqual(metrics["score"], 0.8) if __name__ == "__main__":