diff --git a/python/pyproject.toml b/python/pyproject.toml
index 0dc0ef63d..6eaa6263b 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -18,12 +18,15 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]
 [project.optional-dependencies]
 runtime_common = [
     "aiohttp",
+    "datasets",
     "decord",
     "fastapi",
     "hf_transfer",
     "huggingface_hub",
     "interegular",
+    "llguidance>=0.6.15",
     "modelscope",
+    "ninja",
     "orjson",
     "packaging",
     "pillow",
@@ -33,13 +36,10 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "torchao>=0.7.0",
+    "transformers==4.48.3",
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.14",
-    "ninja",
-    "transformers==4.48.3",
-    "llguidance>=0.6.15",
-    "datasets"
 ]

 srt = [
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 64ef15cf7..6f103bcc6 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -81,7 +81,7 @@ class ModelConfig:
         if context_length is not None:
             if context_length > derived_context_len:
                 if get_bool_env_var(
-                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
+                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
                 ):
                     logger.warning(
                         f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 7c0f287b7..f8a6b4e43 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -106,6 +106,8 @@ class Engine:
         tokenizer_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args
         )
+
+        self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info

diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index f471626e1..ec041305c 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -42,7 +42,6 @@ class Sampler(nn.Module):
         return_logprob: bool,
         top_logprobs_nums: List[int],
         token_ids_logprobs: List[List[int]],
-        batch_next_token_ids: Optional[torch.Tensor] = None,
     ):
         """Run a sampler & compute logprobs and update logits_output accordingly.

@@ -72,8 +71,7 @@ class Sampler(nn.Module):

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            if batch_next_token_ids is None:
-                batch_next_token_ids = torch.argmax(logits, -1)
+            batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
@@ -94,43 +92,39 @@
                         top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                     ).clamp(min=torch.finfo(probs.dtype).min)

-                if batch_next_token_ids is None:
-                    max_top_k_round, batch_size = 32, probs.shape[0]
-                    uniform_samples = torch.rand(
-                        (max_top_k_round, batch_size), device=probs.device
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
+                )
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
                     )
-                    if sampling_info.need_min_p_sampling:
-                        probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                        probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                        batch_next_token_ids = min_p_sampling_from_probs(
-                            probs, uniform_samples, sampling_info.min_ps
-                        )
-                    else:
-                        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                            probs,
-                            uniform_samples,
-                            sampling_info.top_ks,
-                            sampling_info.top_ps,
-                            filter_apply_order="joint",
-                        )
-
-                        if self.use_nan_detection and not torch.all(success):
-                            logger.warning("Detected errors during sampling!")
-                            batch_next_token_ids = torch.zeros_like(
-                                batch_next_token_ids
-                            )
-
-            elif global_server_args_dict["sampling_backend"] == "pytorch":
-                if batch_next_token_ids is None:
-                    # A slower fallback implementation with torch native operations.
-                    batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                         probs,
+                        uniform_samples,
                         sampling_info.top_ks,
                         sampling_info.top_ps,
-                        sampling_info.min_ps,
-                        sampling_info.need_min_p_sampling,
+                        filter_apply_order="joint",
                     )

+                    if self.use_nan_detection and not torch.all(success):
+                        logger.warning("Detected errors during sampling!")
+                        batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                    probs,
+                    sampling_info.top_ks,
+                    sampling_info.top_ps,
+                    sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
+                )
+
             if return_logprob:
                 # clamp to avoid -inf
                 logprobs = torch.log(
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index cb3a4b5de..a5c6a1dbd 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -957,11 +957,13 @@ class Scheduler:
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                 self.batch_is_full = False

+            # Filter batch
             last_bs = self.last_batch.batch_size()
             self.last_batch.filter_batch()
             if self.last_batch.batch_size() < last_bs:
                 self.batch_is_full = False

+            # Merge the new batch into the running batch
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = self.last_batch
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 813fbf6fc..6a2bab22a 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -300,10 +300,11 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
+            # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
-                tqdm.tqdm(self.capture_bs)
+                tqdm.tqdm(reversed(self.capture_bs))
                 if get_tensor_model_parallel_rank() == 0
-                else self.capture_bs
+                else reversed(self.capture_bs)
             )
             for bs in capture_range:
                 with patch_model(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 666b97e2b..6489ea6ed 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -928,45 +928,6 @@ class ModelRunner:
         sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)

-    def update_output_logprobs(
-        self,
-        logits_output: LogitsProcessorOutput,
-        sampling_info: SamplingBatchInfo,
-        top_logprobs_nums: List[int],
-        token_ids_logprobs: List[int],
-        next_token_ids: torch.Tensor,
-        *,
-        num_tokens_per_req: List[int],
-    ):
-        """Update the logits_output's output logprob based on next_token_ids
-
-        Args:
-            logits_output: The logits output from the model forward
-            sampling_info: Sampling info for logprob calculation
-            top_logprobs_nums: Number of logprobs per request.
-            next_token_ids: Next token ids.
-            num_tokens_per_req: The number of tokens per request.
-
-        Returns:
-            A list of next_token_ids
-        """
-        self._preprocess_logits(logits_output, sampling_info)
-        # We should repeat top_logprobs_nums to match num_tokens_per_req.
-        top_logprobs_nums_repeat_interleaved = []
-        token_ids_logprobs_repeat_interleaved = []
-        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
-            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
-        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
-            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
-        self.sampler(
-            logits_output,
-            sampling_info,
-            True,
-            top_logprobs_nums_repeat_interleaved,
-            token_ids_logprobs_repeat_interleaved,
-            batch_next_token_ids=next_token_ids,
-        )
-
     def sample(
         self,
         logits_output: LogitsProcessorOutput,
diff --git a/python/sglang/srt/sampling/penaltylib/frequency_penalty.py b/python/sglang/srt/sampling/penaltylib/frequency_penalty.py
index 691534627..893a1c377 100644
--- a/python/sglang/srt/sampling/penaltylib/frequency_penalty.py
+++ b/python/sglang/srt/sampling/penaltylib/frequency_penalty.py
@@ -56,7 +56,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedFrequencyPenalizer"):
-        print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
         self.frequency_penalties = torch.cat(
             [self.frequency_penalties, their.frequency_penalties], dim=0
         )
diff --git a/python/sglang/srt/sampling/penaltylib/presence_penalty.py b/python/sglang/srt/sampling/penaltylib/presence_penalty.py
index 91266b352..4f3a6ace3 100644
--- a/python/sglang/srt/sampling/penaltylib/presence_penalty.py
+++ b/python/sglang/srt/sampling/penaltylib/presence_penalty.py
@@ -56,7 +56,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedPresencePenalizer"):
-        print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
         self.presence_penalties = torch.cat(
             [self.presence_penalties, their.presence_penalties], dim=0
         )
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 12da787eb..bd2fa6009 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -7,6 +7,7 @@ import torch
 from huggingface_hub import snapshot_download

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
@@ -302,13 +303,10 @@ class EAGLEWorker(TpModelWorker):

             # Set inputs
             forward_batch.input_ids = input_ids
+            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
             forward_batch.out_cache_loc = out_cache_loc[
-                forward_batch.batch_size
-                * self.topk
-                * i : forward_batch.batch_size
-                * self.topk
-                * (i + 1)
-            ]
+                :, self.topk * i : self.topk * (i + 1)
+            ].flatten()
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
@@ -353,42 +351,70 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info = res.draft_input

         if batch.return_logprob:
-            # Compute output logprobs using the sampler.
-            num_tokens_per_req = [
-                accept + 1 for accept in res.accept_length_per_req_cpu
-            ]
-            self.target_worker.model_runner.update_output_logprobs(
-                logits_output,
-                batch.sampling_info,
-                batch.top_logprobs_nums,
-                batch.token_ids_logprobs,
-                res.verified_id,
-                # +1 for bonus token.
-                num_tokens_per_req=num_tokens_per_req,
-            )
-
-            # Add output logprobs to the request.
-            pt = 0
-            # NOTE: tolist() of these values are skipped when output is processed
-            next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
-            verified_ids = res.verified_id.tolist()
-            for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
-                for _ in range(num_tokens):
-                    if req.return_logprob:
-                        token_id = verified_ids[pt]
-                        req.output_token_logprobs_val.append(next_token_logprobs[pt])
-                        req.output_token_logprobs_idx.append(token_id)
-                        if req.top_logprobs_num > 0:
-                            req.output_top_logprobs_val.append(
-                                res.logits_output.next_token_top_logprobs_val[pt]
-                            )
-                            req.output_top_logprobs_idx.append(
-                                res.logits_output.next_token_top_logprobs_idx[pt]
-                            )
-                    pt += 1
+            self.add_logprob_values(batch, res, logits_output)

         return logits_output, res, model_worker_batch

+    def add_logprob_values(
+        self,
+        batch: ScheduleBatch,
+        res: EagleVerifyOutput,
+        logits_output: LogitsProcessorOutput,
+    ):
+        # Extract args
+        logits_output = res.logits_output
+        top_logprobs_nums = batch.top_logprobs_nums
+        token_ids_logprobs = batch.token_ids_logprobs
+        logprobs = torch.nn.functional.log_softmax(
+            logits_output.next_token_logits, dim=-1
+        )
+        batch_next_token_ids = res.verified_id
+        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
+
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+
+        # Extract logprobs
+        if any(x > 0 for x in top_logprobs_nums):
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
+
+        if any(x is not None for x in token_ids_logprobs):
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
+
+        logits_output.next_token_logprobs = logprobs[
+            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
+            batch_next_token_ids,
+        ]
+
+        # Add output logprobs to the request.
+        pt = 0
+        next_token_logprobs = logits_output.next_token_logprobs.tolist()
+        verified_ids = batch_next_token_ids.tolist()
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+            for _ in range(num_tokens):
+                if req.return_logprob:
+                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                    req.output_token_logprobs_idx.append(verified_ids[pt])
+                    if req.top_logprobs_num > 0:
+                        req.output_top_logprobs_val.append(
+                            res.logits_output.next_token_top_logprobs_val[pt]
+                        )
+                        req.output_top_logprobs_idx.append(
+                            res.logits_output.next_token_top_logprobs_idx[pt]
+                        )
+                pt += 1
+
     def forward_draft_extend(
         self,
         batch: ScheduleBatch,
diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py
index 92e828c26..29f7a12a2 100644
--- a/test/lang/test_srt_backend.py
+++ b/test/lang/test_srt_backend.py
@@ -76,7 +76,7 @@ class TestSRTBackend(unittest.TestCase):
         # Run twice to capture more bugs
         for _ in range(2):
             accuracy, latency = test_hellaswag_select()
-            self.assertGreater(accuracy, 0.65)
+            self.assertGreater(accuracy, 0.60)

     def test_gen_min_new_tokens(self):
         test_gen_min_new_tokens()
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index cadca667b..5b89071b6 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -123,7 +123,7 @@ class TestEAGLEEngine(unittest.TestCase):
     def _test_acc_length(self, engine):
         prompt = [
             "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
-        ]
+        ] * 5
         sampling_params = {"temperature": 0, "max_new_tokens": 512}
         output = engine.generate(prompt, sampling_params)
         output = output[0]
@@ -141,10 +141,14 @@ class TestEAGLEEngine(unittest.TestCase):
             / output["meta_info"]["e2e_latency"]
         )
         print(f"{acc_length=}")
-        self.assertGreater(acc_length, 3.6)
+
+        if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
+            self.assertGreater(acc_length, 3.6)
+        else:
+            self.assertGreater(acc_length, 2.6)


-class TestEAGLEEngineTokenMap(unittest.TestCase):
+class TestEAGLEEngineTokenMap(TestEAGLEEngine):
     BASE_CONFIG = {
         "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
         "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
@@ -155,6 +159,7 @@ class TestEAGLEEngineTokenMap(unittest.TestCase):
         "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
         "mem_fraction_static": 0.7,
         "cuda_graph_max_bs": 5,
+        "dtype": "float16",
     }

     NUM_CONFIGS = 1
@@ -245,8 +250,25 @@ class TestEAGLEServer(unittest.TestCase):
         for p in threads:
             p.join()

+    def test_max_token_one(self):
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=1,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+
+        # Just run and check it does not hang
+        metrics = run_eval(args)
+        self.assertGreater(metrics["output_throughput"], 50)
+
     def test_gsm8k(self):
-        server_info = requests.get(self.base_url + "/flush_cache")
+        requests.get(self.base_url + "/flush_cache")

         args = SimpleNamespace(
             num_shots=5,
@@ -391,6 +413,53 @@ class TestEAGLEServer(unittest.TestCase):
         with ThreadPoolExecutor(8) as executor:
             list(executor.map(func, args))

+    def run_decode(self, sampling_params):
+        return_logprob = True
+        top_logprobs_num = 5
+        return_text = True
+        n = 1
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "Human: Write a travel blog post to Hawaii.\n\nAssistant:",
+                "sampling_params": {
+                    "max_new_tokens": 48,
+                    "n": n,
"temperature": 0.7, + **sampling_params, + }, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "return_text_in_logprobs": return_text, + "logprob_start_len": 0, + }, + ) + self.assertEqual(response.status_code, 200) + print(json.dumps(response.json())) + print("=" * 100) + + def test_penalty_mixed(self): + args = [ + {}, + {}, + {}, + {"frequency_penalty": 2}, + {"presence_penalty": 1}, + {"min_new_tokens": 16}, + {"frequency_penalty": 0.2}, + {"presence_penalty": 0.4}, + {"min_new_tokens": 8}, + {"frequency_penalty": 0.4, "presence_penalty": 0.8}, + {"frequency_penalty": 0.4, "min_new_tokens": 12}, + {"presence_penalty": 0.8, "min_new_tokens": 12}, + {"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32}, + {"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32}, + ] + random.shuffle(args * 5) + with ThreadPoolExecutor(8) as executor: + list(executor.map(self.run_decode, args)) + class TestEAGLERetract(TestEAGLEServer): @classmethod diff --git a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py index dd923777f..f5e0e3cdb 100644 --- a/test/srt/test_eval_accuracy_large.py +++ b/test/srt/test_eval_accuracy_large.py @@ -44,11 +44,12 @@ class TestEvalAccuracyLarge(unittest.TestCase): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.71) if is_in_ci(): write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n') + self.assertGreater(metrics["score"], 0.71) + def test_human_eval(self): args = SimpleNamespace( base_url=self.base_url, @@ -59,13 +60,14 @@ class TestEvalAccuracyLarge(unittest.TestCase): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.64) if is_in_ci(): write_github_step_summary( f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n' ) + self.assertGreater(metrics["score"], 0.64) + def test_mgsm_en(self): args = SimpleNamespace( base_url=self.base_url, @@ -76,13 +78,14 @@ class TestEvalAccuracyLarge(unittest.TestCase): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.835) if is_in_ci(): write_github_step_summary( f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n' ) + self.assertGreater(metrics["score"], 0.835) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index a019988ab..b2a831f99 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -1,6 +1,7 @@ import unittest from types import SimpleNamespace +import requests import torch from sglang.srt.utils import kill_process_tree @@ -129,6 +130,8 @@ class TestDeepseekV3MTP(unittest.TestCase): kill_process_tree(cls.process.pid) def test_gsm8k(self): + requests.get(self.base_url + "/flush_cache") + args = SimpleNamespace( num_shots=5, data_path=None, @@ -143,6 +146,11 @@ class TestDeepseekV3MTP(unittest.TestCase): self.assertGreater(metrics["accuracy"], 0.60) + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 2.5) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_penalty.py b/test/srt/test_penalty.py index cb9b6b3dc..e1d11a9ac 100644 --- a/test/srt/test_penalty.py +++ b/test/srt/test_penalty.py @@ -42,7 +42,7 @@ class TestPenalty(unittest.TestCase): # prompt that is supposed to generate < 32 tokens "text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = 
?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", "sampling_params": { - "max_new_tokens": 32, + "max_new_tokens": 48, "n": n, **sampling_params, }, @@ -68,19 +68,22 @@ class TestPenalty(unittest.TestCase): def test_presence_penalty(self): self.run_decode({"presence_penalty": 2}) - def test_mixed(self): + def test_penalty_mixed(self): args = [ {}, {}, {}, {"frequency_penalty": 2}, - {"min_new_tokens": 16}, {"presence_penalty": 1}, + {"min_new_tokens": 16}, {"frequency_penalty": 0.2}, - {"min_new_tokens": 8}, {"presence_penalty": 0.4}, - {"presence_penalty": 0.4, "frequency_penalty": 2}, - {"min_new_tokens": 12, "frequency_penalty": 2}, + {"min_new_tokens": 8}, + {"frequency_penalty": 0.4, "presence_penalty": 0.8}, + {"frequency_penalty": 0.4, "min_new_tokens": 12}, + {"presence_penalty": 0.8, "min_new_tokens": 12}, + {"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32}, + {"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32}, ] random.shuffle(args * 5) with ThreadPoolExecutor(8) as executor: