[EAGLE] many fixes for eagle (#4195)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
Lianmin Zheng
2025-03-07 22:12:13 -08:00
parent d052f4c8a9
commit d4017a6b63
15 changed files with 202 additions and 135 deletions

View File

@@ -123,7 +123,7 @@ class TestEAGLEEngine(unittest.TestCase):
def _test_acc_length(self, engine):
prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
]
] * 5
sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params)
output = output[0]
@@ -141,10 +141,14 @@ class TestEAGLEEngine(unittest.TestCase):
/ output["meta_info"]["e2e_latency"]
)
print(f"{acc_length=}")
self.assertGreater(acc_length, 3.6)
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
self.assertGreater(acc_length, 3.6)
else:
self.assertGreater(acc_length, 2.6)
class TestEAGLEEngineTokenMap(unittest.TestCase):
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
BASE_CONFIG = {
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
@@ -155,6 +159,7 @@ class TestEAGLEEngineTokenMap(unittest.TestCase):
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 5,
"dtype": "float16",
}
NUM_CONFIGS = 1
@@ -245,8 +250,25 @@ class TestEAGLEServer(unittest.TestCase):
for p in threads:
p.join()
def test_max_token_one(self):
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=1,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
# Just run and check it does not hang
metrics = run_eval(args)
self.assertGreater(metrics["output_throughput"], 50)
def test_gsm8k(self):
server_info = requests.get(self.base_url + "/flush_cache")
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
@@ -391,6 +413,53 @@ class TestEAGLEServer(unittest.TestCase):
with ThreadPoolExecutor(8) as executor:
list(executor.map(func, args))
def run_decode(self, sampling_params):
return_logprob = True
top_logprobs_num = 5
return_text = True
n = 1
response = requests.post(
self.base_url + "/generate",
json={
"text": "Human: Write a travel blog post to Hawaii.\n\nAssistant:",
"sampling_params": {
"max_new_tokens": 48,
"n": n,
"temperature": 0.7,
**sampling_params,
},
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text,
"logprob_start_len": 0,
},
)
self.assertEqual(response.status_code, 200)
print(json.dumps(response.json()))
print("=" * 100)
def test_penalty_mixed(self):
args = [
{},
{},
{},
{"frequency_penalty": 2},
{"presence_penalty": 1},
{"min_new_tokens": 16},
{"frequency_penalty": 0.2},
{"presence_penalty": 0.4},
{"min_new_tokens": 8},
{"frequency_penalty": 0.4, "presence_penalty": 0.8},
{"frequency_penalty": 0.4, "min_new_tokens": 12},
{"presence_penalty": 0.8, "min_new_tokens": 12},
{"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32},
{"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32},
]
random.shuffle(args * 5)
with ThreadPoolExecutor(8) as executor:
list(executor.map(self.run_decode, args))
class TestEAGLERetract(TestEAGLEServer):
@classmethod

View File

@@ -44,11 +44,12 @@ class TestEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.71)
if is_in_ci():
write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
self.assertGreater(metrics["score"], 0.71)
def test_human_eval(self):
args = SimpleNamespace(
base_url=self.base_url,
@@ -59,13 +60,14 @@ class TestEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.64)
if is_in_ci():
write_github_step_summary(
f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
)
self.assertGreater(metrics["score"], 0.64)
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
@@ -76,13 +78,14 @@ class TestEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.835)
if is_in_ci():
write_github_step_summary(
f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
)
self.assertGreater(metrics["score"], 0.835)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,6 +1,7 @@
import unittest
from types import SimpleNamespace
import requests
import torch
from sglang.srt.utils import kill_process_tree
@@ -129,6 +130,8 @@ class TestDeepseekV3MTP(unittest.TestCase):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=None,
@@ -143,6 +146,11 @@ class TestDeepseekV3MTP(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.5)
if __name__ == "__main__":
unittest.main()

View File

@@ -42,7 +42,7 @@ class TestPenalty(unittest.TestCase):
# prompt that is supposed to generate < 32 tokens
"text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
"sampling_params": {
"max_new_tokens": 32,
"max_new_tokens": 48,
"n": n,
**sampling_params,
},
@@ -68,19 +68,22 @@ class TestPenalty(unittest.TestCase):
def test_presence_penalty(self):
self.run_decode({"presence_penalty": 2})
def test_mixed(self):
def test_penalty_mixed(self):
args = [
{},
{},
{},
{"frequency_penalty": 2},
{"min_new_tokens": 16},
{"presence_penalty": 1},
{"min_new_tokens": 16},
{"frequency_penalty": 0.2},
{"min_new_tokens": 8},
{"presence_penalty": 0.4},
{"presence_penalty": 0.4, "frequency_penalty": 2},
{"min_new_tokens": 12, "frequency_penalty": 2},
{"min_new_tokens": 8},
{"frequency_penalty": 0.4, "presence_penalty": 0.8},
{"frequency_penalty": 0.4, "min_new_tokens": 12},
{"presence_penalty": 0.8, "min_new_tokens": 12},
{"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32},
{"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32},
]
random.shuffle(args * 5)
with ThreadPoolExecutor(8) as executor: