Split the scheduler's __init__ into smaller functions and improve the EAGLE tests (#4128)

This commit is contained in:
Lianmin Zheng
2025-03-06 00:13:20 -08:00
committed by GitHub
parent 0804dd11a0
commit fcc2e37f69
7 changed files with 279 additions and 341 deletions
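The scheduler refactor itself is not part of the test diffs shown below, so the sketch that follows is only an illustration of the pattern named in the commit title: a monolithic __init__ is split into named helper methods that run in a fixed order. The class and method names here (ServerArgs, init_tokenizer, init_memory_pool_and_cache, init_schedule_policy) are hypothetical stand-ins, not the actual sglang scheduler API.

from dataclasses import dataclass


@dataclass
class ServerArgs:
    # Hypothetical subset of server arguments, for illustration only.
    model_path: str = "meta-llama/Meta-Llama-3-8B-Instruct"
    mem_fraction_static: float = 0.7


class Scheduler:
    def __init__(self, server_args: ServerArgs):
        self.server_args = server_args
        # Each block that used to live inline in __init__ becomes one
        # named, order-dependent step that can be read and tested alone.
        self.init_tokenizer()
        self.init_memory_pool_and_cache()
        self.init_schedule_policy()

    def init_tokenizer(self):
        # Placeholder: tokenizer construction for server_args.model_path.
        self.tokenizer = None

    def init_memory_pool_and_cache(self):
        # Placeholder: KV-pool / radix-cache allocation sized by
        # mem_fraction_static.
        self.memory_pool = {}

    def init_schedule_policy(self):
        # Placeholder: batching/scheduling policy wiring.
        self.policy = "lpm"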

@@ -1,16 +1,20 @@
import multiprocessing as mp
import os
import random
import threading
import time
import unittest
from types import SimpleNamespace
from typing import List, Optional
import requests
import torch
import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
@@ -19,7 +23,9 @@ from sglang.test.test_utils import (
popen_launch_server,
)
acc_rate_tolerance = 0.15
torch_dtype = torch.float16
prefill_tolerance = 5e-2
decode_tolerance: float = 5e-2
class TestEAGLEEngine(unittest.TestCase):
@@ -28,51 +34,72 @@ class TestEAGLEEngine(unittest.TestCase):
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"speculative_eagle_topk": 4,
"speculative_num_draft_tokens": 8,
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 32,
"cuda_graph_max_bs": 5,
}
NUM_CONFIGS = 3
def setUp(self):
self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
ref_engine = sgl.Engine(
model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1
)
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown()
def test_correctness(self):
configs = [
# Basic config
self.BASE_CONFIG,
# Disable cuda graph
{**self.BASE_CONFIG, "disable_cuda_graph": True},
{**self.BASE_CONFIG, "chunked_prefill_size": 2},
# Chunked prefill
{**self.BASE_CONFIG, "chunked_prefill_size": 4},
]
for config in configs:
with self.subTest(
cuda_graph=(
"enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
),
chunked_prefill_size=(
config["chunked_prefill_size"]
if "chunked_prefill_size" in config
else "default"
),
):
engine = sgl.Engine(**config)
for i, config in enumerate(configs[: self.NUM_CONFIGS]):
with self.subTest(i=i):
print(f"{config=}")
engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
try:
self._test_basic_generation(engine)
self._test_eos_token(engine)
self._test_single_generation(engine)
self._test_batch_generation(engine)
self._test_eos_token(engine)
self._test_acc_length(engine)
finally:
engine.shutdown()
print("=" * 100)
def _test_basic_generation(self, engine):
def _test_single_generation(self, engine):
output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{output=}, {self.ref_output=}")
self.assertEqual(output, self.ref_output)
def _test_batch_generation(self, engine):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
params = {"temperature": 0, "max_new_tokens": 50}
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
print(f"{engine.get_server_info()=}")
avg_spec_accept_length = engine.get_server_info()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.9)
def _test_eos_token(self, engine):
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
params = {
@@ -88,32 +115,54 @@ class TestEAGLEEngine(unittest.TestCase):
tokens = tokenizer.encode(output, truncation=False)
self.assertNotIn(tokenizer.eos_token_id, tokens)
def _test_batch_generation(self, engine):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
def _test_acc_length(self, engine):
prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
]
params = {"temperature": 0, "max_new_tokens": 30}
sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params)
output = output[0]
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
if "spec_verify_ct" in output["meta_info"]:
acc_length = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["spec_verify_ct"]
)
else:
acc_length = 1.0
speed = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["e2e_latency"]
)
print(f"{acc_length=}")
self.assertGreater(acc_length, 3.6)
prompts = [
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
BASE_CONFIG = {
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 4,
"speculative_num_draft_tokens": 8,
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 5,
}
NUM_CONFIGS = 1
class TestEAGLEServer(unittest.TestCase):
PROMPTS = [
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
@@ -127,17 +176,17 @@ class TestEAGLEServer(unittest.TestCase):
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"5",
5,
"--speculative-eagle-topk",
"8",
8,
"--speculative-num-draft-tokens",
"64",
64,
"--mem-fraction-static",
"0.7",
0.7,
"--chunked-prefill-size",
"128",
"--cuda-graph-max-bs",
"32",
128,
"--max-running-requests",
8,
],
)
@@ -147,7 +196,7 @@ class TestEAGLEServer(unittest.TestCase):
def send_request(self):
time.sleep(random.uniform(0, 2))
for prompt in prompts:
for prompt in self.PROMPTS:
url = self.base_url + "/generate"
data = {
"text": prompt,
@@ -160,7 +209,7 @@ class TestEAGLEServer(unittest.TestCase):
assert response.status_code == 200
def send_requests_abort(self):
for prompt in prompts:
for prompt in self.PROMPTS:
try:
time.sleep(random.uniform(0, 2))
url = self.base_url + "/generate"
@@ -192,6 +241,8 @@ class TestEAGLEServer(unittest.TestCase):
p.join()
def test_gsm8k(self):
server_info = requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=None,
@@ -201,96 +252,25 @@ class TestEAGLEServer(unittest.TestCase):
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.20)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.9)
def measure_acc_rate(engine):
tic = time.time()
prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>\n\nAssistant:"
]
sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params)
output = output[0]
latency = time.time() - tic
if "spec_verify_ct" in output["meta_info"]:
base_acc_length = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["spec_verify_ct"]
)
else:
base_acc_length = 0.0
base_speed = output["meta_info"]["completion_tokens"] / latency
return base_acc_length, base_speed
# Wait a little bit so that the memory check happens.
time.sleep(4)
class TestEagleAcceptanceRate(unittest.TestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
ref_engine = sgl.Engine(
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
speculative_num_steps=5,
speculative_eagle_topk=8,
speculative_num_draft_tokens=64,
mem_fraction_static=0.7,
disable_radix_cache=True,
)
cls.base_acc_length, cls.base_speed = measure_acc_rate(ref_engine)
ref_engine.shutdown()
assert cls.base_acc_length > 4.45
def test_acc_rate(self):
base_acc_length, base_speed = self.base_acc_length, self.base_speed
chunk_engine = sgl.Engine(
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
speculative_num_steps=5,
speculative_eagle_topk=8,
speculative_num_draft_tokens=64,
mem_fraction_static=0.7,
chunked_prefill_size=2,
disable_radix_cache=True,
)
chunked_acc_length, chunked_base_speed = measure_acc_rate(chunk_engine)
chunk_engine.shutdown()
print(base_acc_length, base_speed)
print(chunked_acc_length, chunked_base_speed)
assert abs(base_acc_length - chunked_acc_length) < acc_rate_tolerance
def test_acc_rate_prefix_caching(self):
base_acc_length, base_speed = self.base_acc_length, self.base_speed
prefix_caching_engine = sgl.Engine(
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
speculative_num_steps=5,
speculative_eagle_topk=8,
speculative_num_draft_tokens=64,
mem_fraction_static=0.7,
chunked_prefill_size=4,
schedule_policy="lpm",
)
for _ in range(10):
acc_length, _ = measure_acc_rate(prefix_caching_engine)
print(f"{acc_length=}")
assert abs(base_acc_length - acc_length) < acc_rate_tolerance
# The second one should hit the prefix cache.
prefix_caching_engine.shutdown()
class TestEAGLERetract(unittest.TestCase):
class TestEAGLERetract(TestEAGLEServer):
@classmethod
def setUpClass(cls):
# This config helps find a leak.
os.environ["SGLANG_CI_SMALL_KV_SIZE"] = "4500"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
@@ -302,41 +282,20 @@ class TestEAGLERetract(unittest.TestCase):
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"5",
5,
"--speculative-eagle-topk",
"8",
8,
"--speculative-num-draft-tokens",
"64",
64,
"--mem-fraction-static",
"0.7",
0.7,
"--chunked-prefill-size",
"128",
128,
"--max-running-requests",
"64",
64,
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.20)
# Wait a little bit so that the memory check happens.
time.sleep(5)
class TestEAGLEServerTriton(TestEAGLEServer):
@classmethod
@@ -352,73 +311,20 @@ class TestEAGLEServerTriton(TestEAGLEServer):
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"5",
5,
"--speculative-eagle-topk",
"4",
8,
"--speculative-num-draft-tokens",
"8",
64,
"--mem-fraction-static",
"0.7",
0.7,
"--attention-backend",
"triton",
"--cuda-graph-max-bs",
"16",
"--max-running-requests",
8,
],
)
class TestEAGLEEngineTokenMap(unittest.TestCase):
def setUp(self):
self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
ref_engine = sgl.Engine(
model_path="meta-llama/Meta-Llama-3-8B-Instruct", cuda_graph_max_bs=2
)
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown()
def test_correctness(self):
config = {
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 4,
"speculative_num_draft_tokens": 8,
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 4,
"dtype": "bfloat16",
}
engine = sgl.Engine(**config)
try:
self._test_basic_generation(engine)
self._test_batch_generation(engine)
finally:
engine.shutdown()
def _test_basic_generation(self, engine):
output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{output=}, {self.ref_output=}")
self.assertEqual(output, self.ref_output)
def _test_batch_generation(self, engine):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
params = {"temperature": 0, "max_new_tokens": 30}
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
if __name__ == "__main__":
unittest.main()
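For context on the assertions in the tests above: acc_length is computed as completion_tokens / spec_verify_ct, i.e. the mean number of tokens the target model accepts per draft-verification round, so the thresholds (1.9, 2.9, 3.6) assert that speculative decoding is actually yielding multi-token steps. A small worked sketch with made-up numbers:

# Hypothetical meta_info values, for illustration only.
meta_info = {"completion_tokens": 512, "spec_verify_ct": 128, "e2e_latency": 4.0}

acc_length = meta_info["completion_tokens"] / meta_info["spec_verify_ct"]
speed = meta_info["completion_tokens"] / meta_info["e2e_latency"]
print(acc_length, speed)  # 4.0 accepted tokens per verify round, 128.0 tokens/s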

@@ -59,6 +59,7 @@ class TestEnableMetrics(unittest.TestCase):
"sglang:spec_accept_length",
"sglang:prompt_tokens_total",
"sglang:generation_tokens_total",
"sglang:cached_tokens_total",
"sglang:num_requests_total",
"sglang:time_to_first_token_seconds",
"sglang:time_per_output_token_seconds",

@@ -94,7 +94,7 @@ class TestEpMoEFP8(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.5
self.assertGreaterEqual(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
@@ -106,7 +106,7 @@ class TestEpMoEFP8(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.8
self.assertGreaterEqual(metrics["score"], 0.8)
if __name__ == "__main__":
unittest.main()
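A side note on the assert-to-assertGreaterEqual change in the hunks above: unittest's comparison assertions report both operands on failure and, unlike a bare assert statement, are not stripped when Python runs with -O. A minimal standalone illustration, not part of the diff:

import unittest


class Demo(unittest.TestCase):
    def test_score(self):
        score = 0.4
        # Fails with: AssertionError: 0.4 not greater than or equal to 0.5
        self.assertGreaterEqual(score, 0.5)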