diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index df3f2c5ea..a1929cbe0 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -50,7 +50,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
+          python3 run_suite.py --suite minimal --range-begin 0 --range-end 6
 
   unit-test-backend-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -67,7 +67,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 5 --range-end 14
+          python3 run_suite.py --suite minimal --range-begin 6 --range-end 14
 
   unit-test-backend-part-3:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -103,6 +103,31 @@ jobs:
           cd test/srt
           python3 run_suite.py --suite minimal --range-begin 21
 
+  unit-test-backend-2-gpu-part-1:
+    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
+    runs-on: 2-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Install dependencies
+        run: |
+          bash scripts/ci_install_dependency.sh
+
+      - name: Evaluate data parallelism accuracy (DP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_data_parallelism.py
+
+      - name: Evaluate MLA accuracy (TP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_mla.py
+          python3 test_mla_fp8.py
+          python3 test_dp_attention.py
+
   performance-test-1-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
@@ -178,6 +203,12 @@ jobs:
         run: |
           bash scripts/ci_install_dependency.sh
+      - name: Benchmark single latency (TP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default
+
       - name: Benchmark offline throughput (TP=2)
         timeout-minutes: 10
         run: |
           cd test/srt
@@ -190,12 +221,6 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
 
-      - name: Benchmark single latency (TP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default
-
   accuracy-test-1-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
@@ -238,23 +263,10 @@ jobs:
           cd test/srt
           python3 test_moe_eval_accuracy_large.py
 
-      - name: Evaluate MLA accuracy (TP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_mla.py
-          python3 test_mla_fp8.py
-          python3 test_dp_attention.py
-
-      - name: Evaluate data parallelism accuracy (DP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_data_parallelism.py
-
   finish:
     needs: [
       unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3, unit-test-backend-part-4,
+      unit-test-backend-2-gpu-part-1,
       performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
       accuracy-test-1-gpu, accuracy-test-2-gpu
     ]
diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 9313fbf6f..ec3308692 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -212,6 +212,7 @@ def extend(reqs, model_runner):
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
         model_config=model_runner.model_config,
+        enable_overlap=False,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
diff --git a/python/sglang/srt/layers/fused_moe/patch.py b/python/sglang/srt/layers/fused_moe_patch.py
similarity index 95%
rename from python/sglang/srt/layers/fused_moe/patch.py
rename to python/sglang/srt/layers/fused_moe_patch.py
index 6e64c89aa..400ca03c4 100644
--- a/python/sglang/srt/layers/fused_moe/patch.py
+++ b/python/sglang/srt/layers/fused_moe_patch.py
@@ -1,3 +1,8 @@
+"""
+Torch-native implementation for FusedMoE. This is used for torch.compile.
+It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
+"""
+
 from typing import Callable, Optional
 
 import torch
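The new module docstring above credits gpt-fast's mixtral-moe as the basis for the torch-native path. For orientation, here is a minimal sketch of that style of MoE forward pass; the function name, weight layouts, and shapes are illustrative assumptions, not the actual contents of fused_moe_patch.py:

import torch
import torch.nn.functional as F


def moe_forward_native_sketch(
    x: torch.Tensor,       # [tokens, hidden]
    router: torch.Tensor,  # [hidden, num_experts], assumed router weights
    w13: torch.Tensor,     # [num_experts, 2 * inter, hidden], fused gate/up proj
    w2: torch.Tensor,      # [num_experts, hidden, inter], down proj
    topk: int,
) -> torch.Tensor:
    # Route each token to its top-k experts and renormalize the weights.
    scores = F.softmax(x @ router, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Gather the selected experts' weights per token and run the SwiGLU MLP
    # with batched einsums instead of per-expert dispatch.
    w13_sel = w13[topk_ids]  # [tokens, topk, 2 * inter, hidden]
    w2_sel = w2[topk_ids]    # [tokens, topk, hidden, inter]
    gate_up = torch.einsum("th,tkoh->tko", x, w13_sel)
    gate, up = gate_up.chunk(2, dim=-1)
    act = F.silu(gate) * up  # SwiGLU activation
    out = torch.einsum("tkn,tkhn->tkh", act, w2_sel)

    # Weighted sum over the k selected experts.
    return (out * topk_weights.unsqueeze(-1).to(x.dtype)).sum(dim=1)

The point of the gather-plus-einsum formulation is that it contains no data-dependent control flow, so torch.compile can trace it into a single static graph.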
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index c9f0ea676..8bfb8e8f7 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -437,9 +437,12 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None
 
-    # For utility
+    # Batch configs
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
+    enable_overlap: bool = False
+
+    # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
 
@@ -488,10 +491,11 @@ class ScheduleBatch:
     def init_new(
         cls,
         reqs: List[Req],
-        req_to_token_pool,
-        token_to_kv_pool,
-        tree_cache,
-        model_config,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: BaseTokenToKVPool,
+        tree_cache: BasePrefixCache,
+        model_config: ModelConfig,
+        enable_overlap: bool,
     ):
         return cls(
             reqs=reqs,
@@ -499,6 +503,7 @@ class ScheduleBatch:
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             model_config=model_config,
+            enable_overlap=enable_overlap,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
@@ -612,7 +617,7 @@ class ScheduleBatch:
 
         assert len(self.out_cache_loc) == self.extend_num_tokens
 
-    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
 
         bs = len(self.reqs)
@@ -706,7 +711,7 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=enable_overlap_schedule,
+            enable_overlap_schedule=self.enable_overlap,
         )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -897,7 +902,7 @@ class ScheduleBatch:
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
 
-    def prepare_for_decode(self, enable_overlap: bool = False):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
         self.input_ids = self.output_ids
@@ -914,7 +919,7 @@ class ScheduleBatch:
         else:
             locs = self.seq_lens
 
-        if enable_overlap:
+        if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
                 (self.req_pool_indices, locs), self.out_cache_loc
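Since enable_overlap is now carried on the batch itself, callers pass the flag once at construction and the prepare_* methods read it from self. A minimal sketch of the updated call shape, mirroring the bench_one_batch.py hunk above (the pool, cache, and config objects are assumed to come from an already initialized model runner):

# Sketch of the post-patch ScheduleBatch.init_new call shape; the
# model_runner attributes are assumptions matching bench_one_batch.py.
batch = ScheduleBatch.init_new(
    reqs=reqs,
    req_to_token_pool=model_runner.req_to_token_pool,
    token_to_kv_pool=model_runner.token_to_kv_pool,
    tree_cache=None,  # no prefix cache in this sketch
    model_config=model_runner.model_config,
    enable_overlap=False,  # bound once here, read later via self.enable_overlap
)
batch.prepare_for_extend()  # the flag is no longer passed per call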
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 5e5b4c685..165c7f66f 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -466,6 +466,7 @@ class Scheduler:
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -842,14 +843,15 @@ class Scheduler:
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
-        new_batch.prepare_for_extend(self.enable_overlap)
+        new_batch.prepare_for_extend()
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.filter_batch()
             if not self.running_batch.is_empty():
-                self.running_batch.prepare_for_decode(self.enable_overlap)
+                self.running_batch.prepare_for_decode()
                 new_batch.mix_with_running(self.running_batch)
                 new_batch.decoding_reqs = self.running_batch.reqs
         self.running_batch = None
@@ -900,7 +902,7 @@ class Scheduler:
             self.batch_is_full = False
 
         # Update batch tensors
-        batch.prepare_for_decode(self.enable_overlap)
+        batch.prepare_for_decode()
         return batch
 
     def run_batch(self, batch: ScheduleBatch):
@@ -1055,6 +1057,7 @@ class Scheduler:
                 continue
 
             if self.enable_overlap and req.finished():
+                # Free the one delayed token
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index a31341470..02bd358b0 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -23,7 +23,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
+from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
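Taken together, the scheduler hunks bind self.enable_overlap at batch construction, so the flag no longer has to be threaded through prepare_for_extend and prepare_for_decode. A condensed sketch of the resulting flow (method bodies elided; can_run_list is a placeholder for whatever requests the scheduling policy selected):

# Condensed sketch of the Scheduler flow after this patch, not the full
# method bodies. can_run_list stands in for the selected requests.
new_batch = ScheduleBatch.init_new(
    can_run_list,
    self.req_to_token_pool,
    self.token_to_kv_pool,
    self.tree_cache,
    self.model_config,
    self.enable_overlap,  # the flag travels with the batch from here on
)
new_batch.prepare_for_extend()  # was prepare_for_extend(self.enable_overlap)

# Later, when updating the running batch:
batch.prepare_for_decode()  # was prepare_for_decode(self.enable_overlap)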
diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py
index 4584a86e0..5e2d6389b 100644
--- a/test/srt/test_bench_serving.py
+++ b/test/srt/test_bench_serving.py
@@ -20,7 +20,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 2850)
+            self.assertGreater(res["output_throughput"], 3350)
 
     def test_offline_throughput_non_stream_small_batch_size(self):
         res = run_bench_serving(
@@ -47,7 +47,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 2900)
+            self.assertGreater(res["output_throughput"], 3350)
 
     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(
@@ -74,7 +74,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 2950)
+            self.assertGreater(res["output_throughput"], 3450)
 
     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(
@@ -85,7 +85,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 3200)
+            self.assertGreater(res["output_throughput"], 3850)
 
     def test_online_latency_default(self):
         res = run_bench_serving(
@@ -109,7 +109,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 1900)
+            self.assertGreater(res["output_throughput"], 2150)
 
     def test_moe_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(
@@ -120,7 +120,7 @@ class TestBenchServing(unittest.TestCase):
         )
 
         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 1950)
+            self.assertGreater(res["output_throughput"], 2150)
 
 
 if __name__ == "__main__":
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 98c124fee..e5018a02c 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -6,6 +6,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 import json
 import unittest
 
+import numpy as np
 import requests
 
 from sglang.srt.utils import kill_child_process
@@ -132,6 +133,7 @@ class TestSRTEndpoint(unittest.TestCase):
         )
 
     def test_logprob_with_chunked_prefill(self):
+        """Test that a long prompt requesting output logprobs does not hit OOM."""
         new_tokens = 4
         prompts = "I have a very good idea on this. " * 8000
 
@@ -154,6 +156,63 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
 
+    def test_logprob_match(self):
+        """Test that the output logprobs are close to the input logprobs if we run a prefill again."""
+
+        def run_generate(
+            prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
+        ):
+
+            if isinstance(prompt, str):
+                prompt_kwargs = {"text": prompt}
+            else:
+                prompt_kwargs = {"input_ids": prompt}
+
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    **prompt_kwargs,
+                    "sampling_params": {
+                        "temperature": 1.0,
+                        "max_new_tokens": max_new_tokens,
+                        "ignore_eos": True,
+                    },
+                    "return_logprob": return_logprob,
+                    "return_text_in_logprobs": True,
+                    "logprob_start_len": logprob_start_len,
+                },
+            )
+            return response.json()
+
+        prompt = "I have a very good idea on how to"
+
+        gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
+        output_logprobs = np.array(
+            [x[0] for x in gen["meta_info"]["output_token_logprobs"]]
+        )
+        num_prompts_tokens = gen["meta_info"]["prompt_tokens"]
+
+        input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
+        output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]
+
+        new_prompt = input_tokens + output_tokens
+        score = run_generate(
+            new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
+        )
+        output_logprobs_score = np.array(
+            [
+                x[0]
+                for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
+            ]
+        )
+
+        print(f"{output_logprobs[-10:]=}")
+        print(f"{output_logprobs_score[-10:]=}")
+
+        diff = np.abs(output_logprobs - output_logprobs_score)
+        max_diff = np.max(diff)
+        self.assertLess(max_diff, 0.2)
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
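The new test_logprob_match doubles as a recipe for scoring a fixed token sequence: replay the prompt plus the generated tokens through /generate with max_new_tokens=0 and logprob_start_len=0, then read per-position logprobs out of input_token_logprobs. A standalone sketch of that request, where base_url and token_ids are placeholders for a running sglang server and a tokenized sequence:

import requests

# Standalone sketch of the scoring pattern from test_logprob_match above.
# base_url and token_ids are placeholders; max_new_tokens=0 means the
# server only prefills and returns logprobs, generating nothing.
response = requests.post(
    base_url + "/generate",
    json={
        "input_ids": token_ids,  # prompt tokens + previously sampled tokens
        "sampling_params": {"temperature": 1.0, "max_new_tokens": 0},
        "return_logprob": True,
        "return_text_in_logprobs": True,
        "logprob_start_len": 0,  # score from the first position
    },
)
logprobs = [x[0] for x in response.json()["meta_info"]["input_token_logprobs"]]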