2025-03-05 08:06:07 -08:00
import multiprocessing as mp
2025-01-11 06:00:43 +08:00
import random
2025-01-21 02:55:14 -08:00
import threading
2025-01-11 06:00:43 +08:00
import time
2025-01-02 19:22:34 +08:00
import unittest
2025-01-21 02:55:14 -08:00
from types import SimpleNamespace
2025-01-02 19:22:34 +08:00
2025-01-11 06:00:43 +08:00
import requests
2025-01-08 13:46:02 +08:00
2025-01-02 19:22:34 +08:00
import sglang as sgl
2025-01-21 02:55:14 -08:00
from sglang . srt . hf_transformers_utils import get_tokenizer
2025-01-11 06:00:43 +08:00
from sglang . srt . utils import kill_process_tree
2025-01-21 02:55:14 -08:00
from sglang . test . few_shot_gsm8k import run_eval
2025-01-11 06:00:43 +08:00
from sglang . test . test_utils import (
2025-01-21 02:55:14 -08:00
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST ,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST ,
2025-01-11 06:00:43 +08:00
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ,
DEFAULT_URL_FOR_TEST ,
popen_launch_server ,
)
2025-01-02 19:22:34 +08:00
2025-03-05 08:06:07 -08:00
# Largest allowed absolute gap in acceptance length between the baseline EAGLE
# engine and a variant engine (chunked prefill / prefix caching) in
# TestEagleAcceptanceRate.
acc_rate_tolerance: float = 0.15
2025-01-02 19:22:34 +08:00
class TestEAGLEEngine(unittest.TestCase):
    """Compare offline-engine EAGLE speculative decoding against plain greedy decoding."""

    # Engine kwargs shared by every configuration exercised in test_correctness.
    BASE_CONFIG = {
        "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
        "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 8,
        "speculative_num_draft_tokens": 64,
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 32,
    }

    def setUp(self):
        """Record the greedy output of the target model run without a draft model."""
        self.prompt = "Today is a sunny day and I like"
        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}

        ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
        # Shut the reference engine down even if generation fails, so it does not
        # keep holding resources while the speculative engines are started.
        try:
            self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)[
                "text"
            ]
        finally:
            ref_engine.shutdown()

    def test_correctness(self):
        """Each EAGLE config must reproduce the reference output and behave sanely."""
        configs = [
            self.BASE_CONFIG,
            {**self.BASE_CONFIG, "disable_cuda_graph": True},
            {**self.BASE_CONFIG, "chunked_prefill_size": 2},
        ]

        for config in configs:
            # Labels are read from the config keys rather than inferred from the
            # dict size, so the chunked-prefill variant is correctly reported as
            # running with CUDA graphs enabled.
            with self.subTest(
                cuda_graph=(
                    "disabled" if config.get("disable_cuda_graph") else "enabled"
                ),
                chunked_prefill_size=config.get("chunked_prefill_size", "default"),
            ):
                engine = sgl.Engine(**config)
                try:
                    self._test_basic_generation(engine)
                    self._test_eos_token(engine)
                    self._test_batch_generation(engine)
                finally:
                    engine.shutdown()

    def _test_basic_generation(self, engine):
        """Greedy output for ``self.prompt`` must equal the reference output."""
        output = engine.generate(self.prompt, self.sampling_params)["text"]
        print(f"{output=}, {self.ref_output=}")
        self.assertEqual(output, self.ref_output)

    def _test_eos_token(self, engine):
        """With special tokens kept, the decoded text must not contain the EOS token id."""
        prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
        params = {
            "temperature": 0,
            "max_new_tokens": 1024,
            "skip_special_tokens": False,
        }

        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
        output = engine.generate(prompt, params)["text"]
        print(f"{output=}")

        tokens = tokenizer.encode(output, truncation=False)
        self.assertNotIn(tokenizer.eos_token_id, tokens)

    def _test_batch_generation(self, engine):
        """Generate for a batch of prompts and print each result.

        Only the number of outputs is checked; the texts are printed for
        manual inspection.
        """
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        params = {"temperature": 0, "max_new_tokens": 30}

        outputs = engine.generate(prompts, params)
        # zip() below would silently drop prompts if fewer outputs came back.
        self.assertEqual(len(outputs), len(prompts))
        for prompt, output in zip(prompts, outputs):
            print(f"Prompt: {prompt}")
            print(f"Generated: {output['text']}")
            print("-" * 40)
2025-01-21 02:55:14 -08:00
2025-01-02 19:22:34 +08:00
2025-01-11 06:00:43 +08:00
# Prompts sent by the TestEAGLEServer request / abort worker threads.
# NOTE: "\\n" in these literals is a backslash followed by "n", not a newline.
prompts = [
    # The trailing comma here was missing, which silently concatenated the
    # first two prompts into one list element.
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]",
    '[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]
2025-01-21 02:55:14 -08:00
class TestEAGLEServer(unittest.TestCase):
    """HTTP-server tests for EAGLE speculative decoding (request aborts, GSM8K accuracy)."""

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                "5",
                "--speculative-eagle-topk",
                "8",
                "--speculative-num-draft-tokens",
                "64",
                "--mem-fraction-static",
                "0.7",
                "--chunked-prefill-size",
                "128",
                "--cuda-graph-max-bs",
                "32",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def send_request(self):
        """After a random 0-2 s delay, send every prompt and require HTTP 200 for each."""
        time.sleep(random.uniform(0, 2))
        url = self.base_url + "/generate"
        for prompt in prompts:
            data = {
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 1024,
                },
            }
            response = requests.post(url, json=data)
            self.assertEqual(response.status_code, 200)

    def send_requests_abort(self):
        """Send every prompt with a 1 s client timeout.

        Requests that take longer than 1 s are abandoned by the client, which
        simulates a disconnect. Transport errors (timeouts, dropped connections)
        are expected here and are only printed.
        """
        url = self.base_url + "/generate"
        for prompt in prompts:
            time.sleep(random.uniform(0, 2))
            data = {
                "model": "base",
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 1024,
                },
            }
            try:
                # set timeout = 1s, mock disconnected
                requests.post(url, json=data, timeout=1)
            except requests.exceptions.RequestException as e:
                print(e)

    def test_request_abort(self):
        """Run normal and client-aborted requests concurrently.

        The normal requests must all succeed.
        """
        concurrency = 4
        errors = []

        def run_and_record(fn):
            # An exception raised inside a worker thread does not fail the test
            # by itself, so record it and report it from the main thread.
            try:
                fn()
            except Exception as e:
                errors.append(e)

        threads = [
            threading.Thread(target=run_and_record, args=(self.send_request,))
            for _ in range(concurrency)
        ] + [
            threading.Thread(target=run_and_record, args=(self.send_requests_abort,))
            for _ in range(concurrency)
        ]
        for worker in threads:
            worker.start()
        for worker in threads:
            worker.join()

        if errors:
            raise AssertionError(
                f"{len(errors)} worker thread(s) failed: {errors!r}"
            ) from errors[0]

    def test_gsm8k(self):
        """Few-shot GSM8K accuracy through the server must exceed 0.20."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.20)
2025-01-11 06:00:43 +08:00
2025-03-05 08:06:07 -08:00
def measure_acc_rate(engine):
    """Run one fixed greedy generation on ``engine`` and report speculative stats.

    Args:
        engine: an object whose ``generate(prompts, sampling_params)`` returns a
            list of dicts carrying a ``"meta_info"`` dict (as ``sgl.Engine`` does
            for a list prompt).

    Returns:
        ``(acc_length, speed)``:
        - ``acc_length`` is ``completion_tokens / spec_verify_ct``, i.e. completion
          tokens per verification step. It is 0.0 when ``spec_verify_ct`` is
          missing or zero.
        - ``speed`` is completion tokens per second of wall time spent in
          ``generate``.
    """
    prompt = [
        "Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>\n\nAssistant:"
    ]
    sampling_params = {"temperature": 0, "max_new_tokens": 512}

    # perf_counter is monotonic and high-resolution; time.time() is neither.
    tic = time.perf_counter()
    output = engine.generate(prompt, sampling_params)[0]
    latency = time.perf_counter() - tic

    meta_info = output["meta_info"]
    completion_tokens = meta_info["completion_tokens"]
    # A missing or zero verify count means no speculative verification was
    # recorded; report 0.0 instead of raising ZeroDivisionError.
    verify_ct = meta_info.get("spec_verify_ct")
    base_acc_length = completion_tokens / verify_ct if verify_ct else 0.0

    base_speed = completion_tokens / latency
    return base_acc_length, base_speed
class TestEagleAcceptanceRate(unittest.TestCase):
    """Chunked prefill and prefix caching must not change EAGLE's acceptance length.

    The baseline comes from an engine with the radix cache disabled. Each
    variant must stay within ``acc_rate_tolerance`` of that baseline.
    """

    @staticmethod
    def _engine_kwargs(**overrides):
        """Return the EAGLE engine settings shared by every engine here, plus overrides."""
        kwargs = {
            "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "speculative_algorithm": "EAGLE",
            "speculative_num_steps": 5,
            "speculative_eagle_topk": 8,
            "speculative_num_draft_tokens": 64,
            "mem_fraction_static": 0.7,
        }
        kwargs.update(overrides)
        return kwargs

    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)
        ref_engine = sgl.Engine(**cls._engine_kwargs(disable_radix_cache=True))
        # Always release the engine, even if measurement fails.
        try:
            cls.base_acc_length, cls.base_speed = measure_acc_rate(ref_engine)
        finally:
            ref_engine.shutdown()
        assert (
            cls.base_acc_length > 4.45
        ), f"baseline acceptance length too low: {cls.base_acc_length}"

    def test_acc_rate(self):
        """Chunked prefill (chunk size 2) must keep the baseline acceptance length."""
        base_acc_length, base_speed = self.base_acc_length, self.base_speed
        chunk_engine = sgl.Engine(
            **self._engine_kwargs(chunked_prefill_size=2, disable_radix_cache=True)
        )
        try:
            chunked_acc_length, chunked_base_speed = measure_acc_rate(chunk_engine)
        finally:
            chunk_engine.shutdown()
        print(base_acc_length, base_speed)
        print(chunked_acc_length, chunked_base_speed)
        self.assertLess(abs(base_acc_length - chunked_acc_length), acc_rate_tolerance)

    def test_acc_rate_prefix_caching(self):
        """Repeated runs with prefix caching (LPM scheduling) must keep the baseline acceptance length."""
        base_acc_length = self.base_acc_length
        prefix_caching_engine = sgl.Engine(
            **self._engine_kwargs(chunked_prefill_size=4, schedule_policy="lpm")
        )
        # shutdown() runs even when an assertion inside the loop fails.
        try:
            # Runs after the first one should hit the prefix cache.
            for _ in range(10):
                acc_length, _speed = measure_acc_rate(prefix_caching_engine)
                print(f"{acc_length=}")
                self.assertLess(abs(base_acc_length - acc_length), acc_rate_tolerance)
        finally:
            prefix_caching_engine.shutdown()
class TestEAGLERetract(unittest.TestCase):
    """GSM8K accuracy on an EAGLE server capped at 64 concurrently running requests.

    NOTE(review): the class name suggests this targets request retraction under
    memory pressure; that intent is not visible here, so confirm against the
    scheduler.
    """

    @classmethod
    def setUpClass(cls):
        flag_pairs = (
            ("--speculative-algorithm", "EAGLE"),
            ("--speculative-draft-model-path", DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST),
            ("--speculative-num-steps", "5"),
            ("--speculative-eagle-topk", "8"),
            ("--speculative-num-draft-tokens", "64"),
            ("--mem-fraction-static", "0.7"),
            ("--chunked-prefill-size", "128"),
            ("--max-running-requests", "64"),
        )
        # Flatten (flag, value) pairs into the flat argv-style list the launcher expects.
        launch_args = [token for pair in flag_pairs for token in pair]

        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K accuracy must exceed 0.20; then pause before teardown."""
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.20)

        # Wait a little bit so that the memory check happens.
        time.sleep(5)
2025-02-10 20:00:42 +08:00
class TestEAGLEServerTriton(TestEAGLEServer):
    """Re-run the inherited TestEAGLEServer tests on a server using the Triton attention backend.

    Compared with the parent's server, this one uses topk=4, 8 draft tokens, a
    CUDA-graph batch-size cap of 16, and no explicit chunked-prefill size.
    """

    @classmethod
    def setUpClass(cls):
        flag_pairs = (
            ("--speculative-algorithm", "EAGLE"),
            ("--speculative-draft-model-path", DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST),
            ("--speculative-num-steps", "5"),
            ("--speculative-eagle-topk", "4"),
            ("--speculative-num-draft-tokens", "8"),
            ("--mem-fraction-static", "0.7"),
            ("--attention-backend", "triton"),
            ("--cuda-graph-max-bs", "16"),
        )
        # Flatten (flag, value) pairs into the flat argv-style list the launcher expects.
        launch_args = [token for pair in flag_pairs for token in pair]

        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )
2025-03-04 13:40:40 -08:00
class TestEAGLEEngineTokenMap(unittest.TestCase):
    """EAGLE with a ``speculative_token_map`` must reproduce plain greedy decoding."""

    # Target model shared by the reference engine and the speculative engine.
    TARGET_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"

    def setUp(self):
        """Record the greedy output of the target model run without speculative decoding."""
        self.prompt = "Today is a sunny day and I like"
        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}

        ref_engine = sgl.Engine(model_path=self.TARGET_MODEL, cuda_graph_max_bs=2)
        # Always release the reference engine, even if generation fails.
        try:
            self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)[
                "text"
            ]
        finally:
            ref_engine.shutdown()

    def test_correctness(self):
        config = {
            "model_path": self.TARGET_MODEL,
            "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
            "speculative_algorithm": "EAGLE",
            "speculative_num_steps": 5,
            "speculative_eagle_topk": 4,
            "speculative_num_draft_tokens": 8,
            "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
            "mem_fraction_static": 0.7,
            "cuda_graph_max_bs": 4,
            "dtype": "bfloat16",
        }

        engine = sgl.Engine(**config)
        try:
            self._test_basic_generation(engine)
            self._test_batch_generation(engine)
        finally:
            engine.shutdown()

    def _test_basic_generation(self, engine):
        """Greedy output for ``self.prompt`` must equal the reference output."""
        output = engine.generate(self.prompt, self.sampling_params)["text"]
        print(f"{output=}, {self.ref_output=}")
        self.assertEqual(output, self.ref_output)

    def _test_batch_generation(self, engine):
        """Generate for a batch of prompts and print each result.

        Only the number of outputs is checked; the texts are printed for
        manual inspection.
        """
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        params = {"temperature": 0, "max_new_tokens": 30}

        outputs = engine.generate(prompts, params)
        # zip() below would silently drop prompts if fewer outputs came back.
        self.assertEqual(len(outputs), len(prompts))
        for prompt, output in zip(prompts, outputs):
            print(f"Prompt: {prompt}")
            print(f"Generated: {output['text']}")
            print("-" * 40)
2025-01-02 19:22:34 +08:00
if __name__ == "__main__":
    # Run every TestCase defined in this module (verbosity=1 is unittest's default).
    unittest.main(verbosity=1)