2025-01-11 06:00:43 +08:00
import random
2025-01-21 02:55:14 -08:00
import threading
2025-01-11 06:00:43 +08:00
import time
2025-01-02 19:22:34 +08:00
import unittest
2025-01-21 02:55:14 -08:00
from types import SimpleNamespace
2025-01-02 19:22:34 +08:00
2025-01-11 06:00:43 +08:00
import requests
2025-01-08 13:46:02 +08:00
2025-01-02 19:22:34 +08:00
import sglang as sgl
2025-01-21 02:55:14 -08:00
from sglang . srt . hf_transformers_utils import get_tokenizer
2025-01-11 06:00:43 +08:00
from sglang . srt . utils import kill_process_tree
2025-01-21 02:55:14 -08:00
from sglang . test . few_shot_gsm8k import run_eval
2025-01-11 06:00:43 +08:00
from sglang . test . test_utils import (
2025-01-21 02:55:14 -08:00
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST ,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST ,
2025-01-11 06:00:43 +08:00
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ,
DEFAULT_URL_FOR_TEST ,
popen_launch_server ,
)
2025-01-02 19:22:34 +08:00
class TestEAGLEEngine(unittest.TestCase):
    """End-to-end checks of EAGLE speculative decoding via the offline ``sgl.Engine`` API."""

    def test_eagle_accuracy(self):
        """Compare EAGLE output with plain decoding and sanity-check EOS and batching.

        Case 1: with temperature 0, the EAGLE engine must return exactly the same
        text as an engine running only the target model.
        Case 2: the generated text, re-encoded with the target tokenizer, must not
        contain the EOS token id.
        Case 3: batched generation returns one output per prompt; outputs are printed.

        Every engine is shut down in a ``finally`` so a failing case does not leave
        a running engine behind for the tests that follow.
        """
        prompt = "Today is a sunny day and I like"
        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        # Get the reference output from the target model without speculation.
        ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
        try:
            ref_output = ref_engine.generate(prompt, sampling_params)["text"]
        finally:
            ref_engine.shutdown()

        # Launch EAGLE engine
        engine = sgl.Engine(
            model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            speculative_algorithm="EAGLE",
            speculative_num_steps=5,
            speculative_eagle_topk=8,
            speculative_num_draft_tokens=64,
            mem_fraction_static=0.7,
        )
        try:
            # Case 1: Test the output of EAGLE engine is the same as normal engine
            out1 = engine.generate(prompt, sampling_params)["text"]
            print(f"{out1=}, {ref_output=}")
            self.assertEqual(out1, ref_output)

            # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
            prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
            sampling_params = {
                "temperature": 0,
                "max_new_tokens": 1024,
                # Keep special tokens in the text so a stray EOS would survive decoding.
                "skip_special_tokens": False,
            }

            tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
            out2 = engine.generate(prompt, sampling_params)["text"]
            print(f"{out2=}")
            tokens = tokenizer.encode(out2, truncation=False)
            self.assertNotIn(tokenizer.eos_token_id, tokens)

            # Case 3: Batched prompts
            prompts = [
                "Hello, my name is",
                "The president of the United States is",
                "The capital of France is",
                "The future of AI is",
            ]
            sampling_params = {"temperature": 0, "max_new_tokens": 30}
            outputs = engine.generate(prompts, sampling_params)
            # zip() below would silently drop missing outputs, so check the count first.
            self.assertEqual(len(outputs), len(prompts))
            for prompt, output in zip(prompts, outputs):
                print("===============================")
                print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
        finally:
            # Shutdown the engine
            engine.shutdown()
2025-01-02 19:22:34 +08:00
2025-01-11 06:00:43 +08:00
# Llama-2 chat-formatted prompts sent by TestEAGLEServer; each element is one
# complete prompt. NOTE: keep the trailing commas -- adjacent string literals
# without a comma are silently concatenated into a single prompt.
prompts = [
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]",
    '[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]
2025-01-21 02:55:14 -08:00
class TestEAGLEServer(unittest.TestCase):
    """Tests against an HTTP server launched with EAGLE speculative decoding enabled."""

    @classmethod
    def setUpClass(cls):
        """Launch one EAGLE-enabled server shared by every test in this class."""
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                "5",
                "--speculative-eagle-topk",
                "8",
                "--speculative-num-draft-tokens",
                "64",
                "--mem-fraction-static",
                "0.7",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the launched server process tree."""
        kill_process_tree(cls.process.pid)

    def send_request(self):
        """POST every prompt in module-level ``prompts`` to ``/generate``; require HTTP 200.

        Sleeps a random 0-2 s first so concurrent callers start staggered.

        Raises:
            AssertionError: if any response status is not 200.
        """
        time.sleep(random.uniform(0, 2))
        for prompt in prompts:
            url = self.base_url + "/generate"
            data = {
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 1024,
                },
            }
            response = requests.post(url, json=data)
            self.assertEqual(response.status_code, 200, response.text)

    def send_requests_abort(self):
        """POST each prompt with a 1 s client timeout to simulate a client disconnecting.

        Request-level errors (the expected timeout) are printed and ignored;
        anything else (e.g. a bug in this method) propagates to the caller.
        """
        for prompt in prompts:
            time.sleep(random.uniform(0, 2))
            url = self.base_url + "/generate"
            data = {
                "model": "base",
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 1024,
                },
            }
            try:
                # set timeout = 1s, mock disconnected
                requests.post(url, json=data, timeout=1)
            except requests.exceptions.RequestException as e:
                print(e)

    def test_request_abort(self):
        """Mix normal requests with client aborts; the normal requests must all succeed.

        Workers run in a thread pool and their exceptions are re-raised here via
        ``Future.result()``. With bare ``threading.Thread`` workers an exception
        (such as ``send_request``'s status assertion) is only printed, so the
        test could never fail.
        """
        from concurrent.futures import ThreadPoolExecutor

        concurrency = 4
        workers = [self.send_request] * concurrency + [
            self.send_requests_abort
        ] * concurrency
        with ThreadPoolExecutor(max_workers=len(workers)) as executor:
            futures = [executor.submit(worker) for worker in workers]
        # Leaving the with-block waits for all workers; now surface any failure.
        for future in futures:
            future.result()

    def test_gsm8k(self):
        """Run few-shot GSM8K against the server and require accuracy above 0.20."""
        # Derive host and port from the same base URL the server was launched on,
        # e.g. "http://127.0.0.1:1234" -> ("http://127.0.0.1", "1234").
        host, _, port = self.base_url.rpartition(":")
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=host,
            port=int(port),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.20)
2025-01-11 06:00:43 +08:00
2025-01-02 19:22:34 +08:00
if __name__ == "__main__":
    # Run directly: collect and run every TestCase defined in this module.
    unittest.main(module="__main__")