2025-03-06 06:13:59 -08:00
import json
2025-03-06 00:13:20 -08:00
import os
2025-01-11 06:00:43 +08:00
import random
2025-01-21 02:55:14 -08:00
import threading
2025-01-11 06:00:43 +08:00
import time
2025-01-02 19:22:34 +08:00
import unittest
2025-03-06 06:13:59 -08:00
from concurrent . futures import ThreadPoolExecutor
from functools import partial
2025-01-21 02:55:14 -08:00
from types import SimpleNamespace
2025-01-02 19:22:34 +08:00
2025-03-06 06:13:59 -08:00
import numpy as np
2025-01-11 06:00:43 +08:00
import requests
2025-03-06 00:13:20 -08:00
import torch
2025-01-08 13:46:02 +08:00
2025-01-02 19:22:34 +08:00
import sglang as sgl
2025-01-21 02:55:14 -08:00
from sglang . srt . hf_transformers_utils import get_tokenizer
2025-01-11 06:00:43 +08:00
from sglang . srt . utils import kill_process_tree
2025-01-21 02:55:14 -08:00
from sglang . test . few_shot_gsm8k import run_eval
2025-01-11 06:00:43 +08:00
from sglang . test . test_utils import (
2025-01-21 02:55:14 -08:00
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST ,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST ,
2025-01-11 06:00:43 +08:00
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ,
DEFAULT_URL_FOR_TEST ,
2025-03-26 07:53:12 +08:00
CustomTestCase ,
2025-01-11 06:00:43 +08:00
popen_launch_server ,
2025-03-06 06:13:59 -08:00
run_logprob_check ,
2025-01-11 06:00:43 +08:00
)
2025-01-02 19:22:34 +08:00
2025-03-06 00:13:20 -08:00
# Module-level numeric settings for the EAGLE speculative-decoding tests.
# NOTE(review): nothing in this file reads these three values directly —
# presumably they are picked up by shared test utilities (e.g. the logprob
# comparison helpers); confirm against sglang.test before changing them.
torch_dtype = torch . float16
# Maximum tolerated logprob deviation during prefill scoring.
prefill_tolerance = 5e-2
# Maximum tolerated logprob deviation during decode.
decode_tolerance : float = 5e-2
2025-03-05 08:06:07 -08:00
2025-01-02 19:22:34 +08:00
2025-03-26 07:53:12 +08:00
class TestEAGLEEngine(CustomTestCase):
    """End-to-end tests for the EAGLE speculative-decoding engine.

    ``setUp`` generates a reference output with a plain (non-speculative)
    engine; each speculative config under test must reproduce it exactly
    under greedy sampling (temperature=0).
    """

    BASE_CONFIG = {
        "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
        "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 4,
        "speculative_num_draft_tokens": 8,
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 5,
    }
    # How many entries of the `configs` list in test_correctness to run.
    # Subclasses with a single config override this to 1.
    NUM_CONFIGS = 2

    def setUp(self):
        self.prompt = "Today is a sunny day and I like"
        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}

        # Produce the greedy reference output without speculative decoding.
        ref_engine = sgl.Engine(
            model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1
        )
        self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
        ref_engine.shutdown()

    def test_correctness(self):
        configs = [
            # Basic config
            self.BASE_CONFIG,
            # Chunked prefill
            {**self.BASE_CONFIG, "chunked_prefill_size": 4},
        ]

        for i, config in enumerate(configs[: self.NUM_CONFIGS]):
            with self.subTest(i=i):
                print(f"{config=}")
                engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
                try:
                    self._test_single_generation(engine)
                    self._test_batch_generation(engine)
                    self._test_eos_token(engine)
                    self._test_acc_length(engine)
                finally:
                    # Always free GPU memory, even when an assertion failed.
                    engine.shutdown()
                print("=" * 100)

    def _test_single_generation(self, engine):
        # Greedy output must match the non-speculative reference exactly.
        output = engine.generate(self.prompt, self.sampling_params)["text"]
        print(f"{output=}, {self.ref_output=}")
        self.assertEqual(output, self.ref_output)

    def _test_batch_generation(self, engine):
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        params = {"temperature": 0, "max_new_tokens": 50}
        outputs = engine.generate(prompts, params)
        for prompt, output in zip(prompts, outputs):
            print(f"Prompt: {prompt}")
            print(f"Generated: {output['text']}")
            print("-" * 40)

        # Fetch server info once and reuse it (the original queried twice).
        server_info = engine.get_server_info()
        print(f"{server_info=}")
        avg_spec_accept_length = server_info["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 1.9)

    def _test_eos_token(self, engine):
        # Generation must stop at EOS: no eos_token_id may appear in the
        # re-encoded output text.
        prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
        params = {
            "temperature": 0.1,
            "max_new_tokens": 1024,
            "skip_special_tokens": False,
        }
        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
        output = engine.generate(prompt, params)["text"]
        print(f"{output=}")
        tokens = tokenizer.encode(output, truncation=False)
        self.assertNotIn(tokenizer.eos_token_id, tokens)

    def _test_acc_length(self, engine):
        prompt = [
            "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
        ] * 5  # test batched generation
        sampling_params = {"temperature": 0, "max_new_tokens": 512}
        output = engine.generate(prompt, sampling_params)
        output = output[0]

        # Average accepted tokens per verification step; 1.0 means no
        # speculation statistics were reported.
        if "spec_verify_ct" in output["meta_info"]:
            acc_length = (
                output["meta_info"]["completion_tokens"]
                / output["meta_info"]["spec_verify_ct"]
            )
        else:
            acc_length = 1.0
        # Decode throughput (tokens/s). The original computed this but never
        # used it; log it so the value is visible in CI output.
        speed = (
            output["meta_info"]["completion_tokens"]
            / output["meta_info"]["e2e_latency"]
        )
        print(f"{acc_length=}, {speed=}")

        if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
            self.assertGreater(acc_length, 3.6)
        else:
            # Looser bound for subclasses that test other model/draft pairs.
            self.assertGreater(acc_length, 2.5)
2025-01-21 02:55:14 -08:00
2025-01-02 19:22:34 +08:00
2025-03-07 22:12:13 -08:00
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
    """Runs the TestEAGLEEngine suite with a frequency-ranked token map
    (FR-Spec) supplied via ``speculative_token_map``."""

    BASE_CONFIG = {
        "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
        "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 4,
        "speculative_num_draft_tokens": 8,
        "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 5,
        "dtype": "float16",
    }
    # Only the basic config is exercised for this model pair.
    NUM_CONFIGS = 1
2025-01-11 06:00:43 +08:00
2025-03-18 10:35:23 -04:00
class TestEAGLE3Engine(TestEAGLEEngine):
    """Runs the TestEAGLEEngine suite with the EAGLE3 algorithm and a
    wider draft tree (topk=16, 64 draft tokens)."""

    BASE_CONFIG = {
        "model_path": "meta-llama/Llama-3.1-8B-Instruct",
        "speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
        "speculative_algorithm": "EAGLE3",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 16,
        "speculative_num_draft_tokens": 64,
        "mem_fraction_static": 0.7,
        "cuda_graph_max_bs": 5,
        "dtype": "float16",
    }
    # Only the basic config is exercised for this model pair.
    NUM_CONFIGS = 1
2025-03-26 07:53:12 +08:00
class TestEAGLEServer(CustomTestCase):
    """HTTP-level tests against a launched server running EAGLE speculative
    decoding: concurrent/aborted requests, GSM8K accuracy, logprob
    consistency, sampling penalties, and constrained decoding."""

    # BUG FIX: the original list was missing a comma after the first entry,
    # so Python's implicit string concatenation silently merged the first two
    # prompts into one. There are now five distinct prompts.
    PROMPTS = [
        "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST] ",
        '[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
        "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
        "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
        "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
    ]

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                5,
                "--speculative-eagle-topk",
                8,
                "--speculative-num-draft-tokens",
                64,
                "--mem-fraction-static",
                0.7,
                "--chunked-prefill-size",
                128,
                "--max-running-requests",
                8,
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def send_request(self):
        """Send every prompt sequentially and require HTTP 200 for each."""
        # Random delay staggers the worker threads so requests interleave.
        time.sleep(random.uniform(0, 2))
        for prompt in self.PROMPTS:
            url = self.base_url + "/generate"
            data = {
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 1024,
                },
            }
            response = requests.post(url, json=data)
            assert response.status_code == 200

    def send_requests_abort(self):
        """Send prompts with a 1s client timeout to simulate disconnects."""
        for prompt in self.PROMPTS:
            try:
                time.sleep(random.uniform(0, 2))
                url = self.base_url + "/generate"
                data = {
                    "model": "base",
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 1024,
                    },
                }
                # set timeout = 1s, mock disconnected
                requests.post(url, json=data, timeout=1)
            except Exception as e:
                # Timeouts are expected here; they are the point of the test.
                print(e)

    def test_request_abort(self):
        # Run normal and aborting clients concurrently; the server must keep
        # serving the normal requests while others disconnect mid-flight.
        concurrency = 4
        threads = [
            threading.Thread(target=self.send_request) for _ in range(concurrency)
        ] + [
            threading.Thread(target=self.send_requests_abort)
            for _ in range(concurrency)
        ]
        for worker in threads:
            worker.start()
        for p in threads:
            p.join()

    def test_max_token_one(self):
        requests.get(self.base_url + "/flush_cache")
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=1,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        # Just run and check it does not hang
        metrics = run_eval(args)
        self.assertGreater(metrics["output_throughput"], 50)

    def test_gsm8k(self):
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )

        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.20)

        server_info = requests.get(self.base_url + "/get_server_info").json()
        avg_spec_accept_length = server_info["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")
        # topk=1 spawns a single draft chain, so fewer tokens are accepted
        # per verify step than with a wider draft tree.
        speculative_eagle_topk = server_info["speculative_eagle_topk"]
        if speculative_eagle_topk == 1:
            self.assertGreater(avg_spec_accept_length, 2.5)
        else:
            self.assertGreater(avg_spec_accept_length, 3.5)

        # Wait a little bit so that the memory check happens.
        time.sleep(4)

    def test_logprob_start_len(self):
        logprob_start_len = 4
        new_tokens = 4
        prompts = [
            "I have a very good idea on",
            "Today is a sunndy day and",
        ]
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": new_tokens,
                },
                "return_logprob": True,
                "top_logprobs_num": 5,
                "logprob_start_len": logprob_start_len,
            },
        )
        response_json = response.json()
        print(json.dumps(response_json, indent=2))

        for res in response_json:
            # Input logprobs must begin exactly at logprob_start_len.
            self.assertEqual(
                res["meta_info"]["prompt_tokens"],
                logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
            )
            self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
            self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)

    def test_logprob_match(self):
        """Test the output logprobs are close to the input logprobs if we run a prefill again."""

        def run_generate(
            prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
        ):
            if isinstance(prompt, str):
                prompt_kwargs = {"text": prompt}
            else:
                prompt_kwargs = {"input_ids": prompt}
            response = requests.post(
                self.base_url + "/generate",
                json={
                    **prompt_kwargs,
                    "sampling_params": {
                        "temperature": 1.0,
                        "max_new_tokens": max_new_tokens,
                        "ignore_eos": True,
                    },
                    "return_logprob": return_logprob,
                    "return_text_in_logprobs": True,
                    "logprob_start_len": logprob_start_len,
                },
            )
            return response.json()

        prompt = "I have a very good idea on how to"
        gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
        output_logprobs = np.array(
            [x[0] for x in gen["meta_info"]["output_token_logprobs"]]
        )
        num_prompts_tokens = gen["meta_info"]["prompt_tokens"]

        # Re-score the prompt + generated tokens with a pure prefill
        # (max_new_tokens=0) and compare the two logprob sets.
        input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
        output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]
        new_prompt = input_tokens + output_tokens
        score = run_generate(
            new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
        )
        output_logprobs_score = np.array(
            [
                x[0]
                for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
            ]
        )

        print(f"{output_logprobs[-10:]=}")
        print(f"{output_logprobs_score[-10:]=}")

        diff = np.abs(output_logprobs - output_logprobs_score)
        max_diff = np.max(diff)
        self.assertLess(max_diff, 0.25)

    def test_logprob_mixed(self):
        args = []
        temperature = 0
        # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
        # Llama 2 context length seems to be only 2k, so we can only test small length.
        for input_len in [200, 500, 1000, 2000]:
            for output_len in [4, 8]:
                for logprob_start_len in [0, 100, 300, 800, 1998]:
                    for return_logprob in [True, False]:
                        for top_logprobs_num in [0, 5]:
                            if logprob_start_len >= input_len:
                                continue
                            args.append(
                                (
                                    input_len,
                                    output_len,
                                    temperature,
                                    logprob_start_len,
                                    return_logprob,
                                    top_logprobs_num,
                                )
                            )

        random.shuffle(args)

        func = partial(run_logprob_check, self)
        with ThreadPoolExecutor(8) as executor:
            list(executor.map(func, args))

    def run_decode(self, sampling_params):
        """POST one /generate request with the given extra sampling params."""
        return_logprob = True
        top_logprobs_num = 5
        return_text = True
        n = 1

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "Human: Write a travel blog post to Hawaii.\n\nAssistant:",
                "sampling_params": {
                    "max_new_tokens": 48,
                    "n": n,
                    "temperature": 0.7,
                    **sampling_params,
                },
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        self.assertEqual(response.status_code, 200)
        print(json.dumps(response.json()))
        print("=" * 100)

    def test_penalty_mixed(self):
        args = [
            {},
            {},
            {},
            {"frequency_penalty": 2},
            {"presence_penalty": 1},
            {"min_new_tokens": 16},
            {"frequency_penalty": 0.2},
            {"presence_penalty": 0.4},
            {"min_new_tokens": 8},
            {"frequency_penalty": 0.4, "presence_penalty": 0.8},
            {"frequency_penalty": 0.4, "min_new_tokens": 12},
            {"presence_penalty": 0.8, "min_new_tokens": 12},
            {"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32},
            {"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32},
        ]
        # BUG FIX: `random.shuffle(args * 5)` shuffled a temporary list and
        # threw it away (shuffle is in-place and returns None), so only the
        # original 14 configs ran, in order. Repeat 5x, then shuffle the list
        # that is actually submitted to the executor.
        args = args * 5
        random.shuffle(args)
        with ThreadPoolExecutor(8) as executor:
            list(executor.map(self.run_decode, args))

    def test_constrained_decoding(self) :
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Give me a json"},
        ]
        response = requests.post(
            self.base_url + "/v1/chat/completions",
            json={
                "model": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
                "messages": messages,
                "temperature": 0,
                "response_format": {"type": "json_object"},
            },
        )
        self.assertEqual(response.status_code, 200)
        res = response.json()

        # Validate response structure
        self.assertIn("choices", res)
        self.assertEqual(len(res["choices"]), 1)
        self.assertIn("message", res["choices"][0])
        self.assertIn("content", res["choices"][0]["message"])

        # Validate JSON content
        content_json = res["choices"][0]["message"]["content"]
        is_valid_json = True
        try:
            content = json.loads(content_json)
            self.assertIsInstance(content, dict)
        except Exception:
            print(f"parse JSON failed: {content_json}")
            is_valid_json = False
        self.assertTrue(is_valid_json)
2025-03-05 08:06:07 -08:00
2025-03-06 00:13:20 -08:00
class TestEAGLERetract(TestEAGLEServer):
    """Re-runs the TestEAGLEServer suite with a deliberately small KV cache
    and a high request concurrency, so that requests get retracted under
    memory pressure."""

    @classmethod
    def setUpClass(cls):
        # These config helps find a leak.
        os.environ["SGLANG_CI_SMALL_KV_SIZE"] = "4500"
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "--speculative-num-steps",
            5,
            "--speculative-eagle-topk",
            8,
            "--speculative-num-draft-tokens",
            64,
            "--mem-fraction-static",
            0.7,
            "--chunked-prefill-size",
            128,
            "--max-running-requests",
            64,
        ]
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )
2025-02-10 20:00:42 +08:00
class TestEAGLEServerTriton(TestEAGLEServer):
    """Re-runs the TestEAGLEServer suite with the Triton attention backend."""

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "--speculative-num-steps",
            5,
            "--speculative-eagle-topk",
            8,
            "--speculative-num-draft-tokens",
            64,
            "--mem-fraction-static",
            0.7,
            "--attention-backend",
            "triton",
            "--max-running-requests",
            8,
        ]
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )
2025-05-28 07:32:05 +08:00
class TestEAGLEDraftExtend(CustomTestCase):
    """Tests the draft-extend path with a minimal speculative config
    (1 step, topk 1, 2 draft tokens) on the FA3 attention backend."""

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                1,
                "--speculative-eagle-topk",
                1,
                "--speculative-num-draft-tokens",
                2,
                "--max-running-requests",
                4,
                "--attention-backend",
                "fa3",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_one_batch_accept_length(self):
        prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        url = self.base_url + "/generate"
        data = {
            "text": prompts,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 512,
            },
        }
        response = requests.post(url, json=data)
        self.assertEqual(response.status_code, 200)
        outputs = response.json()

        # Idiom fix: iterate the outputs directly instead of indexing with
        # `for i in range(len(prompts))`.
        for output in outputs:
            # Average accepted tokens per verification step; 1.0 when no
            # speculation statistics were reported.
            if "spec_verify_ct" in output["meta_info"]:
                acc_length = (
                    output["meta_info"]["completion_tokens"]
                    / output["meta_info"]["spec_verify_ct"]
                )
            else:
                acc_length = 1.0
            print(f"{acc_length=}")
            self.assertGreater(acc_length, 1.50)
2025-01-02 19:22:34 +08:00
# Script entry point: run the full unittest suite in this module.
if __name__ == "__main__":
    unittest.main()