Support sliding window in triton backend (#6509)
This commit is contained in:
@@ -78,6 +78,7 @@ suites = {
|
||||
TestFile("test_triton_attention_kernels.py", 4),
|
||||
TestFile("test_triton_attention_backend.py", 134),
|
||||
TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
|
||||
TestFile("test_triton_sliding_window.py", 250),
|
||||
TestFile("test_update_weights_from_disk.py", 114),
|
||||
TestFile("test_update_weights_from_tensor.py", 48),
|
||||
TestFile("test_vertex_endpoint.py", 31),
|
||||
|
||||
132
test/srt/test_triton_sliding_window.py
Normal file
132
test/srt/test_triton_sliding_window.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestSlidingWindowAttentionTriton(CustomTestCase):
    """Test sliding window attention functionality with triton backend."""

    @classmethod
    def setUpClass(cls):
        """Set up shared config; each test launches (and tears down) its own server."""
        # Gemma3 model supports sliding window attention
        cls.model = "google/gemma-3-4b-it"
        cls.base_url = DEFAULT_URL_FOR_TEST

        cls.common_args = [
            "--trust-remote-code",
            "--attention-backend",
            "triton",
            "--context-length",
            "8192",
            "--random-seed",
            "42",
        ]

        cls.short_context_prompt = "The capital of France is"

        # Test prompt longer than window size
        cls.long_context_prompt = (
            """
Once upon a time, there was a mountain. In the mountain, there was a temple. In the temple, there was an old monk telling a story. The story was:
"""
            * 100
        )
        cls.long_context_prompt += "\nNow, summarize the story in one sentence:"

    @classmethod
    def tearDownClass(cls):
        # Nothing to clean up at class level: servers are launched and torn
        # down per test in _run_server_tests.
        pass

    def _test_mmlu(self):
        """Run a small MMLU eval against the running server and check the score."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        print(f"MMLU metrics with sliding window: {metrics}")

        self.assertGreaterEqual(metrics["score"], 0.64)

    def _test_short_context_generation(self):
        """Greedy-decode a short prompt and check the expected answer appears."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.short_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )

        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertIn("paris", result["text"].lower())
        print(f"Short context generation result: {result['text']}")

    def _test_long_context_generation(self):
        """Greedy-decode a prompt longer than the sliding window; output must be non-empty."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": self.long_context_prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 256,
                },
            },
        )

        self.assertEqual(response.status_code, 200)
        result = response.json()
        self.assertGreater(len(result["text"].strip()), 0)
        print(f"Long context generation result: {result['text'][:100]}...")

    def _run_server_tests(self, extra_args):
        """Launch a server with common_args + extra_args, run all checks, and
        always shut the server down.

        The try/finally guarantees the launched server process is killed even
        when one of the sub-checks raises, so a failing test cannot leak a
        server that keeps the port occupied for the other test.
        """
        process = popen_launch_server(
            self.model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=self.common_args + extra_args,
        )
        try:
            self._test_short_context_generation()
            self._test_long_context_generation()
            self._test_mmlu()
        finally:
            kill_process_tree(process.pid)
            # Give the OS a moment to release the port before the next launch.
            time.sleep(5)

    def test_no_cuda_graph(self):
        """Sliding-window attention with CUDA graph disabled."""
        self._run_server_tests(["--disable-cuda-graph"])

    def test_cuda_graph(self):
        """Sliding-window attention with CUDA graph enabled (default args)."""
        self._run_server_tests([])
|
||||
|
||||
|
||||
# Allow running this test module directly (outside the pytest suite runner).
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user