Fix a draft model accuracy bug in eagle; support step=1; return logprob in eagle (#4134)

Co-authored-by: Sehoon Kim <kssteven418@gmail.com>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Sehoon Kim <sehoon@x.ai>
Lianmin Zheng
2025-03-06 06:13:59 -08:00
committed by GitHub
parent 3a3918121f
commit bc1534ff32
11 changed files with 304 additions and 106 deletions
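For context on the "return logprob in eagle" part of this change: token logprobs are requested through the plain `/generate` HTTP endpoint, the same one exercised by the new tests in this diff. The snippet below is a minimal, illustrative sketch of such a request; it is not part of the commit, the base URL is an assumption, and it presumes an sglang server is already running with EAGLE speculative decoding enabled. The request fields mirror the ones used by the tests further down.

# Minimal sketch (assumption: a local sglang server launched with EAGLE
# speculative decoding; the address below is illustrative only).
import requests

base_url = "http://127.0.0.1:30000"

response = requests.post(
    base_url + "/generate",
    json={
        "text": "I have a very good idea on",
        "sampling_params": {"temperature": 0, "max_new_tokens": 4},
        "return_logprob": True,        # logprobs now also returned on the EAGLE path
        "top_logprobs_num": 5,
        "logprob_start_len": 0,
    },
)
meta = response.json()["meta_info"]
print(meta["input_token_logprobs"])    # prompt-token logprobs from logprob_start_len onward
print(meta["output_token_logprobs"])   # one entry per generated token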

View File

@@ -165,7 +165,7 @@ class TestBenchServing(unittest.TestCase):
            f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
            f'accept_length : {res["accept_length"]:.2f} \n'
        )

-        self.assertLess(res["median_e2e_latency_ms"], 1100)
+        self.assertLess(res["median_e2e_latency_ms"], 900)
        self.assertGreater(res["accept_length"], 2.99)

    def test_moe_offline_throughput_default(self):

View File

@@ -1,12 +1,16 @@
import json
import multiprocessing as mp
import os
import random
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from types import SimpleNamespace
from typing import List, Optional
import numpy as np
import requests
import torch
@@ -21,6 +25,7 @@ from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
+    run_logprob_check,
)

torch_dtype = torch.float16
@@ -260,11 +265,132 @@ class TestEAGLEServer(unittest.TestCase):
        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
-        self.assertGreater(avg_spec_accept_length, 2.9)
+        self.assertGreater(avg_spec_accept_length, 3.5)

        # Wait a little bit so that the memory check happens.
        time.sleep(4)

    def test_logprob_start_len(self):
        logprob_start_len = 4
        new_tokens = 4
        prompts = [
            "I have a very good idea on",
            "Today is a sunndy day and",
        ]

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": new_tokens,
                },
                "return_logprob": True,
                "top_logprobs_num": 5,
                "logprob_start_len": logprob_start_len,
            },
        )
        response_json = response.json()
        print(json.dumps(response_json, indent=2))

        for res in response_json:
            self.assertEqual(
                res["meta_info"]["prompt_tokens"],
                logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
            )
            self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
            self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)

    def test_logprob_match(self):
        """Test the output logprobs are close to the input logprobs if we run a prefill again."""

        def run_generate(
            prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
        ):
            if isinstance(prompt, str):
                prompt_kwargs = {"text": prompt}
            else:
                prompt_kwargs = {"input_ids": prompt}

            response = requests.post(
                self.base_url + "/generate",
                json={
                    **prompt_kwargs,
                    "sampling_params": {
                        "temperature": 1.0,
                        "max_new_tokens": max_new_tokens,
                        "ignore_eos": True,
                    },
                    "return_logprob": return_logprob,
                    "return_text_in_logprobs": True,
                    "logprob_start_len": logprob_start_len,
                },
            )
            return response.json()

        prompt = "I have a very good idea on how to"

        gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
        output_logprobs = np.array(
            [x[0] for x in gen["meta_info"]["output_token_logprobs"]]
        )
        num_prompts_tokens = gen["meta_info"]["prompt_tokens"]

        input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
        output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]

        new_prompt = input_tokens + output_tokens
        score = run_generate(
            new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
        )
        output_logprobs_score = np.array(
            [
                x[0]
                for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
            ]
        )

        print(f"{output_logprobs[-10:]=}")
        print(f"{output_logprobs_score[-10:]=}")

        diff = np.abs(output_logprobs - output_logprobs_score)
        max_diff = np.max(diff)
        self.assertLess(max_diff, 0.25)

    def test_logprob_mixed(self):
        args = []
        temperature = 0
        # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
        # Llama 2 context length seems to be only 2k, so we can only test small length.
        for input_len in [200, 500, 1000, 2000]:
            for output_len in [4, 8]:
                for logprob_start_len in [0, 100, 300, 800, 1998]:
                    for return_logprob in [True, False]:
                        for top_logprobs_num in [0, 5]:
                            if logprob_start_len >= input_len:
                                continue
                            args.append(
                                (
                                    input_len,
                                    output_len,
                                    temperature,
                                    logprob_start_len,
                                    return_logprob,
                                    top_logprobs_num,
                                )
                            )

        random.shuffle(args)

        func = partial(run_logprob_check, self)
        with ThreadPoolExecutor(8) as executor:
            list(executor.map(func, args))


class TestEAGLERetract(TestEAGLEServer):
    @classmethod

View File

@@ -143,11 +143,11 @@ class TestGPTQModelDynamic(unittest.TestCase):
print(f"result = `{result}`")
assert "paris" in result["text"].lower()
self.assertIn("paris", result["text"].lower())
throughput = max_tokens / (tok - tic)
print(f"Throughput: {throughput} tokens/s")
assert throughput >= 140
self.assertGreaterEqual(throughput, 140)
def test_gptq_module(self):
check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)