[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
@@ -18,6 +19,8 @@ from sglang.test.test_utils import (
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
acc_rate_tolerance = 0.15
|
||||
|
||||
|
||||
class TestEAGLEEngine(unittest.TestCase):
|
||||
BASE_CONFIG = {
|
||||
@@ -43,13 +46,19 @@ class TestEAGLEEngine(unittest.TestCase):
|
||||
configs = [
|
||||
self.BASE_CONFIG,
|
||||
{**self.BASE_CONFIG, "disable_cuda_graph": True},
|
||||
{**self.BASE_CONFIG, "chunked_prefill_size": 2},
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
with self.subTest(
|
||||
cuda_graph=(
|
||||
"enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
|
||||
)
|
||||
),
|
||||
chunked_prefill_size=(
|
||||
config["chunked_prefill_size"]
|
||||
if "chunked_prefill_size" in config
|
||||
else "default"
|
||||
),
|
||||
):
|
||||
engine = sgl.Engine(**config)
|
||||
try:
|
||||
@@ -125,6 +134,8 @@ class TestEAGLEServer(unittest.TestCase):
|
||||
"64",
|
||||
"--mem-fraction-static",
|
||||
"0.7",
|
||||
"--chunked-prefill-size",
|
||||
"128",
|
||||
"--cuda-graph-max-bs",
|
||||
"32",
|
||||
],
|
||||
@@ -196,6 +207,137 @@ class TestEAGLEServer(unittest.TestCase):
|
||||
self.assertGreater(metrics["accuracy"], 0.20)
|
||||
|
||||
|
||||
def measure_acc_rate(engine):
|
||||
tic = time.time()
|
||||
prompt = [
|
||||
"Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>\n\nAssistant:"
|
||||
]
|
||||
sampling_params = {"temperature": 0, "max_new_tokens": 512}
|
||||
output = engine.generate(prompt, sampling_params)
|
||||
output = output[0]
|
||||
latency = time.time() - tic
|
||||
|
||||
if "spec_verify_ct" in output["meta_info"]:
|
||||
base_acc_length = (
|
||||
output["meta_info"]["completion_tokens"]
|
||||
/ output["meta_info"]["spec_verify_ct"]
|
||||
)
|
||||
else:
|
||||
base_acc_length = 0.0
|
||||
|
||||
base_speed = output["meta_info"]["completion_tokens"] / latency
|
||||
return base_acc_length, base_speed
|
||||
|
||||
|
||||
class TestEagleAcceptanceRate(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
ref_engine = sgl.Engine(
|
||||
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
speculative_algorithm="EAGLE",
|
||||
speculative_num_steps=5,
|
||||
speculative_eagle_topk=8,
|
||||
speculative_num_draft_tokens=64,
|
||||
mem_fraction_static=0.7,
|
||||
disable_radix_cache=True,
|
||||
)
|
||||
cls.base_acc_length, cls.base_speed = measure_acc_rate(ref_engine)
|
||||
ref_engine.shutdown()
|
||||
assert cls.base_acc_length > 4.45
|
||||
|
||||
def test_acc_rate(self):
|
||||
base_acc_length, base_speed = self.base_acc_length, self.base_speed
|
||||
chunk_engine = sgl.Engine(
|
||||
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
speculative_algorithm="EAGLE",
|
||||
speculative_num_steps=5,
|
||||
speculative_eagle_topk=8,
|
||||
speculative_num_draft_tokens=64,
|
||||
mem_fraction_static=0.7,
|
||||
chunked_prefill_size=2,
|
||||
disable_radix_cache=True,
|
||||
)
|
||||
chunked_acc_length, chunked_base_speed = measure_acc_rate(chunk_engine)
|
||||
chunk_engine.shutdown()
|
||||
print(base_acc_length, base_speed)
|
||||
print(chunked_acc_length, chunked_base_speed)
|
||||
assert abs(base_acc_length - chunked_acc_length) < acc_rate_tolerance
|
||||
|
||||
def test_acc_rate_prefix_caching(self):
|
||||
base_acc_length, base_speed = self.base_acc_length, self.base_speed
|
||||
prefix_caching_engine = sgl.Engine(
|
||||
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
speculative_algorithm="EAGLE",
|
||||
speculative_num_steps=5,
|
||||
speculative_eagle_topk=8,
|
||||
speculative_num_draft_tokens=64,
|
||||
mem_fraction_static=0.7,
|
||||
chunked_prefill_size=4,
|
||||
schedule_policy="lpm",
|
||||
)
|
||||
for _ in range(10):
|
||||
acc_length, _ = measure_acc_rate(prefix_caching_engine)
|
||||
print(f"{acc_length=}")
|
||||
assert abs(base_acc_length - acc_length) < acc_rate_tolerance
|
||||
# The second one should hit the prefix cache.
|
||||
prefix_caching_engine.shutdown()
|
||||
|
||||
|
||||
class TestEAGLERetract(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--speculative-algorithm",
|
||||
"EAGLE",
|
||||
"--speculative-draft-model-path",
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
"--speculative-num-steps",
|
||||
"5",
|
||||
"--speculative-eagle-topk",
|
||||
"8",
|
||||
"--speculative-num-draft-tokens",
|
||||
"64",
|
||||
"--mem-fraction-static",
|
||||
"0.7",
|
||||
"--chunked-prefill-size",
|
||||
"128",
|
||||
"--max-running-requests",
|
||||
"64",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.20)
|
||||
# Wait a little bit so that the memory check happens.
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
class TestEAGLEServerTriton(TestEAGLEServer):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
Reference in New Issue
Block a user