Support speculative decoding in hybrid attention backend (#9573)
This commit is contained in:
@@ -7,6 +7,8 @@ import requests
|
||||
from sglang.srt.utils import get_device_sm, kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
@@ -36,7 +38,7 @@ class TestHybridAttnBackendBase(CustomTestCase):
|
||||
base_url = DEFAULT_URL_FOR_TEST
|
||||
accuracy_threshold = 0.65 # derived tests need to override this
|
||||
speculative_decode = False
|
||||
spec_decode_threshold = 1.0 # derived spec decoding tests need to override this
|
||||
spec_decode_threshold = 2.2 # derived spec decoding tests need to override this
|
||||
|
||||
@classmethod
|
||||
def get_server_args(cls):
|
||||
@@ -49,8 +51,12 @@ class TestHybridAttnBackendBase(CustomTestCase):
|
||||
# please don't do this if you want to make your inference workload faster
|
||||
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
|
||||
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
|
||||
if cls.speculative_decode:
|
||||
model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
|
||||
else:
|
||||
model = cls.model
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=cls.get_server_args(),
|
||||
@@ -105,5 +111,26 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase):
|
||||
return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"]
|
||||
|
||||
|
||||
class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
    """Hybrid attention backend test variant run with EAGLE speculative decoding."""

    # Switches the base class onto its speculative-decoding code path.
    speculative_decode = True
    # This eagle test uses a very small model, so the accuracy is low.
    accuracy_threshold = 0.2

    @classmethod
    def get_server_args(cls):
        """Return the default server args extended with the EAGLE speculative flags."""
        return DEFAULT_SERVER_ARGS + [
            "--speculative-algorithm",
            "EAGLE",
            # Spell the option out in full: the abbreviated "--speculative-draft"
            # only resolves through argparse prefix matching, which fails as
            # ambiguous as soon as another option shares that prefix.
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "2",
            "--speculative-num-draft-tokens",
            "4",
        ]
|
||||
|
||||
|
||||
# Script entry point: hand control to unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user