[EAGLE] many fixes for eagle (#4195)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
Lianmin Zheng
2025-03-07 22:12:13 -08:00
parent d052f4c8a9
commit d4017a6b63
15 changed files with 202 additions and 135 deletions

View File

@@ -123,7 +123,7 @@ class TestEAGLEEngine(unittest.TestCase):
def _test_acc_length(self, engine):
prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
]
] * 5
sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params)
output = output[0]
@@ -141,10 +141,14 @@ class TestEAGLEEngine(unittest.TestCase):
/ output["meta_info"]["e2e_latency"]
)
print(f"{acc_length=}")
self.assertGreater(acc_length, 3.6)
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
self.assertGreater(acc_length, 3.6)
else:
self.assertGreater(acc_length, 2.6)
class TestEAGLEEngineTokenMap(unittest.TestCase):
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
BASE_CONFIG = {
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
@@ -155,6 +159,7 @@ class TestEAGLEEngineTokenMap(unittest.TestCase):
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 5,
"dtype": "float16",
}
NUM_CONFIGS = 1
@@ -245,8 +250,25 @@ class TestEAGLEServer(unittest.TestCase):
for p in threads:
p.join()
def test_max_token_one(self):
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=1,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
# Just run and check it does not hang
metrics = run_eval(args)
self.assertGreater(metrics["output_throughput"], 50)
def test_gsm8k(self):
server_info = requests.get(self.base_url + "/flush_cache")
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
@@ -391,6 +413,53 @@ class TestEAGLEServer(unittest.TestCase):
with ThreadPoolExecutor(8) as executor:
list(executor.map(func, args))
def run_decode(self, sampling_params):
return_logprob = True
top_logprobs_num = 5
return_text = True
n = 1
response = requests.post(
self.base_url + "/generate",
json={
"text": "Human: Write a travel blog post to Hawaii.\n\nAssistant:",
"sampling_params": {
"max_new_tokens": 48,
"n": n,
"temperature": 0.7,
**sampling_params,
},
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text,
"logprob_start_len": 0,
},
)
self.assertEqual(response.status_code, 200)
print(json.dumps(response.json()))
print("=" * 100)
def test_penalty_mixed(self):
args = [
{},
{},
{},
{"frequency_penalty": 2},
{"presence_penalty": 1},
{"min_new_tokens": 16},
{"frequency_penalty": 0.2},
{"presence_penalty": 0.4},
{"min_new_tokens": 8},
{"frequency_penalty": 0.4, "presence_penalty": 0.8},
{"frequency_penalty": 0.4, "min_new_tokens": 12},
{"presence_penalty": 0.8, "min_new_tokens": 12},
{"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32},
{"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32},
]
random.shuffle(args * 5)
with ThreadPoolExecutor(8) as executor:
list(executor.map(self.run_decode, args))
class TestEAGLERetract(TestEAGLEServer):
@classmethod