diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index fa0fefdaf..8d4f839e8 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -144,18 +144,18 @@ jobs: cd test/srt python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache - - name: Benchmark Offline Throughput (w/o ChunkedPrefill) - timeout-minutes: 10 - run: | - cd test/srt - python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_chunked_prefill - - name: Benchmark Offline Throughput (w/ Triton) timeout-minutes: 10 run: | cd test/srt python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with_triton_attention_backend + - name: Benchmark Offline Throughput (w/ FP8) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8 + performance-test-2-gpu: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 2-gpu-runner diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 57c88f226..a01619b3a 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -305,8 +305,6 @@ class LlamaForCausalLM(nn.Module): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.param_dict = dict(self.named_parameters()) - @torch.no_grad() def forward( self, @@ -374,7 +372,7 @@ class LlamaForCausalLM(nn.Module): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - params_dict = self.param_dict + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index de37a00e6..c17651064 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -36,6 +36,7 @@ class LlamaForClassification(nn.Module): ) -> None: super().__init__() self.config = config + self.torchao_config = None self.quant_config = quant_config self.model = LlamaModel(config, quant_config=quant_config) @@ -44,8 +45,6 @@ class LlamaForClassification(nn.Module): ) self.eos_token_id = config.eos_token_id - self.param_dict = dict(self.named_parameters()) - @torch.no_grad() def forward( self, @@ -77,7 +76,7 @@ class LlamaForClassification(nn.Module): return logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - params_dict = self.param_dict + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "classification_head" in name: diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index 0f4e2a89f..388bd66f7 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -307,8 +307,6 @@ class XverseForCausalLM(nn.Module): self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.param_dict = dict(self.named_parameters()) - @torch.no_grad() def forward( self, @@ -333,7 +331,7 @@ class XverseForCausalLM(nn.Module): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - params_dict = self.param_dict + params_dict = dict(self.named_parameters()) def load_weights_per_param(name, loaded_weight): if "rotary_emb.inv_freq" in name or "projector" in name: diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index b8cf2c8af..2f53682d1 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -383,8 +383,6 @@ class XverseMoeForCausalLM(nn.Module): ) self.logits_processor = LogitsProcessor(config) - self.param_dict = dict(self.named_parameters()) - @torch.no_grad() def forward( self, @@ -406,8 +404,7 @@ class XverseMoeForCausalLM(nn.Module): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - - params_dict = self.param_dict + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 218dc2ccf..f6e5f3ca0 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -22,6 +22,7 @@ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.srt.utils import kill_child_process from sglang.utils import get_exception_traceback +DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8" DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct" DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index eee6d7701..2a327f858 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -1,6 +1,7 @@ import unittest from sglang.test.test_utils import ( + DEFAULT_FP8_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST, is_in_ci, @@ -59,6 +60,17 @@ class TestBenchServing(unittest.TestCase): if is_in_ci(): assert res["output_throughput"] > 2600 + def test_offline_throughput_default_fp8(self): + res = run_bench_serving( + model=DEFAULT_FP8_MODEL_NAME_FOR_TEST, + num_prompts=500, + request_rate=float("inf"), + other_server_args=[], + ) + + if is_in_ci(): + assert res["output_throughput"] > 3100 + def test_online_latency_default(self): res = run_bench_serving( model=DEFAULT_MODEL_NAME_FOR_TEST, diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py index 057f42f05..6f6824416 100644 --- a/test/srt/test_chunked_prefill.py +++ b/test/srt/test_chunked_prefill.py @@ -12,8 +12,10 @@ from sglang.test.test_utils import ( class TestChunkedPrefill(unittest.TestCase): - def run_mmlu(self, disable_radix_cache, enable_mixed_chunk): - other_args = ["--chunked-prefill-size", "32"] + def run_mmlu( + self, disable_radix_cache, enable_mixed_chunk, chunked_prefill_size=32 + ): + other_args = ["--chunked-prefill-size", str(chunked_prefill_size)] if disable_radix_cache: other_args += ["--disable-radix-cache"] @@ -55,6 +57,11 @@ class TestChunkedPrefill(unittest.TestCase): def test_mixed_chunked_prefill_without_radix_cache(self): self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True) + def test_no_chunked_prefill(self): + self.run_mmlu( + disable_radix_cache=False, enable_mixed_chunk=False, chunked_prefill_size=-1 + ) + if __name__ == "__main__": unittest.main()