2025-01-11 06:00:43 +08:00
import multiprocessing
import random
import time
2025-01-02 19:22:34 +08:00
import unittest
2025-01-11 06:00:43 +08:00
import requests
2025-01-08 13:46:02 +08:00
from transformers import AutoConfig , AutoTokenizer
2025-01-02 19:22:34 +08:00
import sglang as sgl
2025-01-11 06:00:43 +08:00
from sglang . srt . utils import kill_process_tree
from sglang . test . test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ,
DEFAULT_URL_FOR_TEST ,
popen_launch_server ,
)
2025-01-02 19:22:34 +08:00
class TestEAGLEEngine(unittest.TestCase):
    """Offline ``sgl.Engine`` tests for EAGLE speculative decoding.

    NOTE(review): the string literals in this copy of the file carry padding
    spaces (model paths, dict keys). They are reproduced unchanged here;
    confirm against the upstream file.
    """

    @staticmethod
    def _eagle_engine_kwargs(target_model_path, speculative_draft_model_path):
        """Return the ``sgl.Engine`` keyword arguments used by both tests."""
        return dict(
            model_path=target_model_path,
            speculative_draft_model_path=speculative_draft_model_path,
            speculative_algorithm=" EAGLE ",
            speculative_num_steps=3,
            speculative_eagle_topk=4,
            speculative_num_draft_tokens=16,
        )

    def _generate_text(self, prompt, sampling_params, **engine_kwargs):
        """Run one generation on a freshly created engine.

        Args:
            prompt: Text passed to ``engine.generate``.
            sampling_params: Sampling options passed to ``engine.generate``.
            **engine_kwargs: Keyword arguments forwarded to ``sgl.Engine``.

        Returns:
            The ``" text "`` entry of the generation result.
        """
        engine = sgl.Engine(**engine_kwargs)
        try:
            return engine.generate(prompt, sampling_params)[" text "]
        finally:
            # Shut down even when generation raises, so a failing test does
            # not leave the engine running.
            engine.shutdown()

    def test_eagle_accuracy(self):
        """Output with the draft model must equal output without it."""
        prompt = " Today is a sunny day and I like "
        target_model_path = " meta-llama/Llama-2-7b-chat-hf "
        speculative_draft_model_path = " lmzheng/sglang-EAGLE-llama2-chat-7B "
        sampling_params = {" temperature ": 0, " max_new_tokens ": 8}

        out1 = self._generate_text(
            prompt,
            sampling_params,
            **self._eagle_engine_kwargs(
                target_model_path, speculative_draft_model_path
            ),
        )
        out2 = self._generate_text(
            prompt, sampling_params, model_path=target_model_path
        )

        print(" ==== Answer 1 ==== ")
        print(out1)
        print(" ==== Answer 2 ==== ")
        print(out2)
        self.assertEqual(out1, out2)

    def test_eagle_end_check(self):
        """Re-encoded output must not contain the tokenizer's EOS token id."""
        prompt = " [INST] <<SYS>> \\ nYou are a helpful assistant. \\ n<</SYS>> \\ nToday is a sunny day and I like [/INST] "
        target_model_path = " meta-llama/Llama-2-7b-chat-hf "
        tokenizer = AutoTokenizer.from_pretrained(target_model_path)
        speculative_draft_model_path = " lmzheng/sglang-EAGLE-llama2-chat-7B "
        sampling_params = {
            " temperature ": 0,
            " max_new_tokens ": 1024,
            " skip_special_tokens ": False,
        }

        out1 = self._generate_text(
            prompt,
            sampling_params,
            **self._eagle_engine_kwargs(
                target_model_path, speculative_draft_model_path
            ),
        )

        print(" ==== Answer 1 ==== ")
        print(repr(out1))
        tokens = tokenizer.encode(out1, truncation=False)
        # assertNotIn is not stripped under -O and reports both operands.
        self.assertNotIn(tokenizer.eos_token_id, tokens)
2025-01-02 19:22:34 +08:00
2025-01-11 06:00:43 +08:00
# [INST]/<<SYS>>-wrapped prompts sent by the request worker functions.
# Each entry needs its own trailing comma: adjacent string literals without
# one are silently concatenated into a single prompt.
prompts = [
    " [INST] <<SYS>> \\ nYou are a helpful assistant. \\ n<</SYS>> \\ nToday is a sunny day and I like[/INST] ",
    ' [INST] <<SYS>> \\ nYou are a helpful assistant. \\ n<</SYS>> \\ nWhat are the mental triggers in Jeff Walker \' s Product Launch Formula and " Launch " book?[/INST] ',
    " [INST] <<SYS>> \\ nYou are a helpful assistant. \\ n<</SYS>> \\ nSummarize Russell Brunson ' s Perfect Webinar Script...[/INST] ",
    " [INST] <<SYS>> \\ nYou are a helpful assistant. \\ n<</SYS>> \\ nwho are you?[/INST] ",
    " [INST] <<SYS>> \\ nYou are a helpful assistant. \\ n<</SYS>> \\ nwhere are you from?[/INST] ",
]
def process(server_url: str):
    """POST every entry of ``prompts`` to ``server_url`` and require HTTP 200.

    Args:
        server_url: Full URL the generation requests are posted to.

    Raises:
        AssertionError: If any response status code is not 200.
    """
    # Sleep a random 0-2 s before the first request.
    time.sleep(random.uniform(0, 2))
    for text in prompts:
        payload = {
            " model ": " base ",
            " text ": text,
            " sampling_params ": {
                " temperature ": 0,
                " max_new_tokens ": 1024,
            },
        }
        reply = requests.post(server_url, json=payload)
        assert reply.status_code == 200
def abort_process(server_url: str):
    """POST every entry of ``prompts`` with a 1-second client timeout.

    The short timeout is used to mock a client that disconnects while the
    request is in flight. Request failures are expected and ignored; the
    response is never inspected.

    Args:
        server_url: Full URL the generation requests are posted to.
    """
    for prompt in prompts:
        time.sleep(1)
        data = {
            " model ": " base ",
            " text ": prompt,
            " sampling_params ": {
                " temperature ": 0,
                " max_new_tokens ": 1024,
            },
        }
        try:
            # timeout=1 s: mock a disconnected client.
            requests.post(server_url, json=data, timeout=1)
        except requests.exceptions.RequestException:
            # Timeouts and connection errors are the point of this worker.
            # Catching only request errors keeps KeyboardInterrupt and
            # SystemExit from being swallowed as a bare ``except`` would.
            pass
class TestEAGLELaunchServer(unittest.TestCase):
    """Server tests for EAGLE speculative decoding under concurrent clients.

    ``setUpClass`` launches one server with ``popen_launch_server`` and
    ``tearDownClass`` kills its process tree. The tests drive it from
    separate ``multiprocessing`` workers.

    NOTE(review): the string literals in this copy of the file carry padding
    spaces (CLI flags, URL suffix). They are reproduced unchanged here;
    confirm against the upstream file.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server shared by every test in this class."""
        speculative_draft_model_path = " lmzheng/sglang-EAGLE-llama2-chat-7B "
        cls.model = " meta-llama/Llama-2-7b-chat-hf "
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                " --speculative-algorithm ",
                " EAGLE ",
                " --speculative-draft-model-path ",
                speculative_draft_model_path,
                " --speculative-num-steps ",
                " 3 ",
                " --speculative-eagle-topk ",
                " 4 ",
                " --speculative-num-draft-tokens ",
                " 16 ",
                " --served-model-name ",
                " base ",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the launched server together with its child processes."""
        kill_process_tree(cls.process.pid)

    def _make_workers(self, target, count):
        """Create ``count`` unstarted processes that call ``target(url)``."""
        url = self.base_url + " /generate "
        # The URL is passed positionally, so the call does not depend on a
        # keyword-name string matching the worker's parameter name.
        return [
            multiprocessing.Process(target=target, args=(url,))
            for _ in range(count)
        ]

    @staticmethod
    def _run_all(workers):
        """Start every worker in order, then wait for all of them."""
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    def _assert_clean_exit(self, workers):
        """Fail the test if any worker ended with a non-zero exit code.

        An exception raised inside a child process (including a failed
        ``assert``) is not propagated to the parent; it is only visible
        through ``Process.exitcode``.
        """
        for worker in workers:
            self.assertEqual(
                worker.exitcode,
                0,
                f"worker {worker.name} exited with code {worker.exitcode}",
            )

    def test_eagle_server_concurrency(self):
        """Four concurrent request workers must all finish successfully."""
        concurrency = 4
        workers = self._make_workers(process, concurrency)
        self._run_all(workers)
        self._assert_clean_exit(workers)

    def test_eagle_server_request_abort(self):
        """Normal workers must succeed while other clients abort requests."""
        concurrency = 4
        senders = self._make_workers(process, concurrency)
        aborters = self._make_workers(abort_process, concurrency)
        # Same start/join order as before: senders first, then aborters.
        self._run_all(senders + aborters)
        # Only the normal senders are checked; the aborting clients are
        # best-effort and deliberately ignore request failures.
        self._assert_clean_exit(senders)
2025-01-02 19:22:34 +08:00
# Run the test suite when this file is executed as a script.
# The guard compares against the exact module name "__main__"; a padded
# literal can never match and would leave the runner unreachable.
if __name__ == "__main__":
    unittest.main()