2024-01-25 01:16:25 +08:00
import argparse
import json
import time
from concurrent . futures import ThreadPoolExecutor
from functools import partial
import guidance
2024-04-28 21:06:22 +08:00
from tqdm import tqdm
2024-05-05 16:14:17 +08:00
from sglang . test . test_utils import add_common_other_args_and_parse , get_call_generate
2024-02-01 13:38:47 +08:00
from sglang . utils import dump_state_text , read_jsonl
2024-01-25 01:16:25 +08:00
# The JSON regex converted from the pydantic model triggers some FSM bugs,
# so a hand-written regex string is used here instead of:
# regex_string = build_regex_from_object(HarryPoterRole)
# Regex used for constrained decoding of one Harry Potter character as a JSON
# object (passed as ``regex=character_regex`` by character_gen and by the lmql
# path of bench_character). It is the concatenation of raw-string fragments,
# one per line of the target JSON: name, house, blood status, occupation, a
# nested wand object (wood, core, length), alive, patronus and bogart.
# Free-text values are limited to 1-16 characters of the class [\w\d\s]
# (\d is already covered by \w); house, blood status, occupation and alive are
# fixed alternations; wand length is 1-2 digits, a dot, then 0-2 digits.
# NOTE(review): the spacing inside these literals looks altered by text
# extraction - confirm the exact whitespace against the canonical file.
character_regex = (
r """ \ { \ n """
+ r """ " name " : " [ \ w \ d \ s] { 1,16} " , \ n """
+ r """ " house " : " (Gryffindor|Slytherin|Ravenclaw|Hufflepuff) " , \ n """
+ r """ " blood status " : " (Pure-blood|Half-blood|Muggle-born) " , \ n """
+ r """ " occupation " : " (student|teacher|auror|ministry of magic|death eater|order of the phoenix) " , \ n """
+ r """ " wand " : \ { \ n """
+ r """ " wood " : " [ \ w \ d \ s] { 1,16} " , \ n """
+ r """ " core " : " [ \ w \ d \ s] { 1,16} " , \ n """
+ r """ " length " : [0-9] { 1,2} \ .[0-9] { 0,2} \ n """
+ r """ \ }, \ n """
+ r """ " alive " : " (Alive|Deceased) " , \ n """
+ r """ " patronus " : " [ \ w \ d \ s] { 1,16} " , \ n """
+ r """ " bogart " : " [ \ w \ d \ s] { 1,16} " \ n """
+ r """ \ } """
)
2024-02-01 13:38:47 +08:00
# Regex used for constrained decoding of one city as a JSON object (passed as
# ``regex=city_regex`` by city_gen). Built from raw-string fragments, one per
# line of the target JSON: name, country, latitude, population and
# "top 3 landmarks" (an array of exactly three strings).
# String values are limited to 1-16 characters of the class [\w\d\s];
# population is an optionally signed integer of 1-9 digits.
# NOTE(review): every component of the latitude pattern is optional (sign,
# integer digits, dot, up to two decimals), so it can match an empty value -
# confirm whether that is intended.
# NOTE(review): the spacing inside these literals looks altered by text
# extraction - confirm the exact whitespace against the canonical file.
city_regex = (
r """ \ { \ n """
+ r """ " name " : " [ \ w \ d \ s] { 1,16} " , \ n """
+ r """ " country " : " [ \ w \ d \ s] { 1,16} " , \ n """
+ r """ " latitude " : [-+]?[0-9]* \ .?[0-9] { 0,2}, \ n """
+ r """ " population " : [-+]?[0-9] { 1,9}, \ n """
+ r """ " top 3 landmarks " : \ [ " [ \ w \ d \ s] { 1,16} " , " [ \ w \ d \ s] { 1,16} " , " [ \ w \ d \ s] { 1,16} " \ ] \ n """
+ r """ \ } """
)
2024-01-25 01:16:25 +08:00
# fmt: off
def character_gen(name, generate):
    """Build the character prompt and append a regex-constrained completion.

    Args:
        name: Character name; it becomes the first words of the prompt.
        generate: Callable invoked as
            ``generate(prompt, max_tokens=256, regex=character_regex)``; its
            return value is concatenated onto the prompt.

    Returns:
        The prompt followed by the generated text.
    """
    prompt = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    completion = generate(prompt, max_tokens=256, regex=character_regex)
    return prompt + completion
# fmt: on
2024-02-01 13:38:47 +08:00
# fmt: off
def city_gen(document, generate):
    """Build the city-extraction prompt and append a regex-constrained completion.

    Args:
        document: Text of the page, inserted verbatim between the
            "Page begin." and "Page end." markers.
        generate: Callable invoked as
            ``generate(prompt, max_tokens=256, regex=city_regex)``; its return
            value is concatenated onto the prompt.

    Returns:
        The prompt followed by the generated text.
    """
    prompt = "".join(
        (
            "Please extract the information of a city from the following wikipedia page.\n",
            "Page begin.\n",
            document,
            "Page end.\n",
            "Here is the name, country, and symbol of the city in JSON format.\n",
        )
    )
    completion = generate(prompt, max_tokens=256, regex=city_regex)
    return prompt + completion
# fmt: on
2024-01-25 01:16:25 +08:00
# Guidance program for the character task: appends to `lm` a prompt sentence
# built from `name`, followed by a JSON skeleton whose values are filled in by
# guidance calls embedded in the f-string, then returns the extended `lm`.
# - name, wood, core, patronus, bogart: guidance.gen, max_tokens=16,
#   constrained by regex_str_no_quote.
# - wand length: guidance.gen, max_tokens=10, constrained by regex_float.
# - house, blood status, occupation, alive: guidance.select over fixed options.
# The doubled braces in the template are f-string escapes for literal JSON braces.
# NOTE(review): the template below is runtime text; its whitespace (and the
# date-like lines inside it) look like extraction artifacts - confirm against
# the canonical file.
@guidance
def character_maker ( lm , name ) :
# Character class for unquoted free text; unbounded here, so the length cap
# comes from max_tokens on each gen call.
regex_str_no_quote = r " [ \ w \ d \ s]+ "
# Decimal number with at least one digit on each side of the dot.
regex_float = r " [0-9]+ \ .[0-9]+ "
lm + = f """ \
2024-02-05 11:22:06 +00:00
{ name } is a character in Harry Potter . Please fill in the following information about this character .
2024-01-25 01:16:25 +08:00
{ {
" name " : " { guidance.gen( " name " , max_tokens=16, regex=regex_str_no_quote)} " ,
" house " : " { guidance.select(options=[ ' Gryffindor ' , ' Slytherin ' , ' Ravenclaw ' , ' Hufflepuff ' ], name= ' house ' )} " ,
" blood status " : " { guidance.select(options=[ ' Pure-blood ' , ' Half-blood ' , ' Muggle-born ' ], name= ' blood status ' )} " ,
" occupation " : " { guidance.select(options=[ ' student ' , ' teacher ' , ' auror ' , ' ministry of magic ' , ' death eater ' , ' order of the phoenix ' ], name= ' occupation ' )} " ,
" wand " : { {
" wood " : " { guidance.gen( " wood " , max_tokens=16, regex=regex_str_no_quote)} " ,
" core " : " { guidance.gen( ' core ' , max_tokens=16, regex=regex_str_no_quote)} " ,
" length " : { guidance . gen ( ' length ' , max_tokens = 10 , regex = regex_float ) }
} } ,
" alive " : " { guidance.select(options=[ ' Alive ' , ' Deceased ' ], name= ' alive ' )} " ,
" patronus " : " { guidance.gen( ' patronus ' , max_tokens=16, regex=regex_str_no_quote)} " ,
" bogart " : " { guidance.gen( ' bogart ' , max_tokens=16, regex=regex_str_no_quote)} "
} }
"""
# Hand the extended model state back to the caller.
return lm
2024-05-05 16:14:17 +08:00
# Run one regex-constrained generation through LMQL and return its result.
# `prompt` is passed to the query as `question`; `temperature`, `max_tokens`,
# `max_len`, `regex` and any extra **kwargs are forwarded as keyword arguments
# of the query call. `model` defaults to None but is required (see the assert).
async def call_generate_lmql (
prompt , temperature , max_tokens , regex , max_len = 4096 , model = None , * * kwargs
) :
# NOTE(review): assert statements are skipped under `python -O`, so this
# check is not a reliable guard in optimized runs.
assert model is not None
# Imported inside the function - presumably so lmql is only needed when this
# backend is actually used.
import lmql
# The docstring of `program` is the LMQL query text, not documentation: it
# emits the question, then constrains ANSWER to fewer than `max_tokens` tokens
# and to match `regex`, and returns ANSWER.
@lmql.query ( model = model )
async def program ( question , max_tokens , regex ) :
''' lmql
""" {question} [ANSWER] """ where len ( TOKENS ( ANSWER ) ) < max_tokens and REGEX ( ANSWER , regex )
return ANSWER
'''
return await program (
question = prompt ,
temperature = temperature ,
max_tokens = max_tokens ,
max_len = max_len ,
regex = regex ,
* * kwargs ,
)
2024-02-01 13:38:47 +08:00
# Guidance program for the city task: appends to `lm` the extraction prompt
# with `document` embedded between the "Page begin" / "Page end" markers,
# followed by a JSON skeleton whose values are filled in by guidance.gen calls
# embedded in the f-string, then returns the extended `lm`.
# - name, country, landmark1..3: max_tokens=16, regex_str_no_quote.
# - latitude: max_tokens=10, regex_float.
# - population: max_tokens=10, digits only.
# The doubled braces in the template are f-string escapes for literal JSON braces.
# NOTE(review): the template below is runtime text; its whitespace looks
# altered by text extraction - confirm against the canonical file.
@guidance
def city_maker ( lm , document ) :
# Character class for unquoted free text; unbounded here, so the length cap
# comes from max_tokens on each gen call.
regex_str_no_quote = r " [ \ w \ d \ s]+ "
# Decimal number with at least one digit on each side of the dot.
regex_float = r " [0-9]+ \ .[0-9]+ "
lm + = f """ \
Please extract the information of a city from the following wikipedia page .
Page begin .
{ document }
Page end .
Here is the name , country , and symbol of the city in JSON format .
{ {
" name " : " { guidance.gen( " name " , max_tokens=16, regex=regex_str_no_quote)} " ,
" country " : " { guidance.gen( " country " , max_tokens=16, regex=regex_str_no_quote)} " ,
" latitude " : { guidance . gen ( " latitude " , max_tokens = 10 , regex = regex_float ) } ,
" population " : { guidance . gen ( " population " , max_tokens = 10 , regex = r " [0-9]+ " ) } ,
" top 3 landmarks " : [
" { guidance.gen( " landmark1 " , max_tokens=16, regex=regex_str_no_quote)} " , " { guidance.gen( " landmark2 " , max_tokens=16, regex=regex_str_no_quote)} " , " { guidance.gen( " landmark3 " , max_tokens=16, regex=regex_str_no_quote)} "
]
} }
"""
# Hand the extended model state back to the caller.
return lm
def bench_character(args):
    """Benchmark constrained JSON generation for Harry Potter characters.

    Reads one character name per line from ``args.data_path``, keeps the first
    ``args.num_jsons`` of them, and generates one JSON object per name with the
    backend named by ``args.backend`` ("outlines", "guidance" or "lmql").

    Args:
        args: Parsed command-line namespace. Reads ``data_path``,
            ``num_jsons``, ``backend`` and ``parallel``, plus ``model_path``
            and ``n_ctx`` (guidance) or ``model_path``, ``host`` and ``port``
            (lmql).

    Returns:
        A ``(states, latency)`` tuple: the per-input results in input order and
        the wall-clock seconds spent generating (setup excluded).

    Raises:
        ValueError: If ``args.backend`` is not one of the supported backends.
    """
    with open(args.data_path, "r") as f:
        arguments = [{"name": line.strip()} for line in f]
    arguments = arguments[: args.num_jsons]
    num_requests = len(arguments)
    states = [None] * num_requests
    backend = args.backend

    # Select backend
    if backend == "outlines":
        call_generate = partial(get_call_generate(args), temperature=0)

        def get_one_answer(i):
            states[i] = character_gen(**arguments[i], generate=call_generate)

    elif backend == "guidance":
        model = guidance.models.LlamaCpp(
            args.model_path,
            n_gpu_layers=-1,
            n_ctx=args.n_ctx,
        )

        def get_one_answer(i):
            states[i] = model + character_maker(**arguments[i])

    elif backend == "lmql":
        import asyncio

        import lmql

        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
        call_generate = partial(
            call_generate_lmql,
            model=model,
            max_tokens=256,
            regex=character_regex,
        )

        # NOTE(review): the prompt here is only the character name, whereas the
        # outlines path sends the full instruction sentence from character_gen.
        async def get_one_answer_async(i):
            states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    tic = time.time()
    if backend == "lmql":
        # Issue the coroutines in consecutive batches of at most args.parallel.
        step = args.parallel
        batches = [
            list(range(start, min(start + step, num_requests)))
            for start in range(0, num_requests, step)
        ]
        loop = asyncio.get_event_loop()
        for batch in tqdm(batches):
            loop.run_until_complete(
                asyncio.gather(*[get_one_answer_async(i) for i in batch])
            )
    elif args.parallel == 1:
        for i in tqdm(range(num_requests)):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            # list() drains the iterator, which also surfaces worker exceptions.
            list(
                tqdm(
                    executor.map(get_one_answer, list(range(num_requests))),
                    total=num_requests,
                )
            )
    latency = time.time() - tic

    return states, latency
def bench_city_doc(args):
    """Benchmark constrained JSON extraction of city facts from documents.

    Loads records with ``read_jsonl(args.data_path)``, keeps the ``"document"``
    field of the first ``args.num_jsons`` records, and generates one JSON
    object per document with the backend named by ``args.backend``
    ("outlines" or "guidance").

    Args:
        args: Parsed command-line namespace. Reads ``data_path``,
            ``num_jsons``, ``backend`` and ``parallel``, plus ``model_path``
            and ``n_ctx`` for the guidance backend.

    Returns:
        A ``(states, latency)`` tuple: the per-input results in input order and
        the wall-clock seconds spent generating (setup excluded).

    Raises:
        ValueError: If ``args.backend`` is not supported by this benchmark.
    """
    arguments = [{"document": line["document"]} for line in read_jsonl(args.data_path)]
    arguments = arguments[: args.num_jsons]
    num_requests = len(arguments)
    states = [None] * num_requests

    # Select backend
    if args.backend == "outlines":
        call_generate = partial(get_call_generate(args), temperature=0)

        def get_one_answer(i):
            states[i] = city_gen(**arguments[i], generate=call_generate)

    elif args.backend == "guidance":
        model = guidance.models.LlamaCpp(
            args.model_path,
            n_gpu_layers=-1,
            n_ctx=args.n_ctx,
        )

        def get_one_answer(i):
            states[i] = model + city_maker(**arguments[i])

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(num_requests)):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            # Show progress in parallel mode too, matching the serial path and
            # bench_character. Draining the map iterator is still required: it
            # is what re-raises an exception from a worker in this thread.
            results = executor.map(get_one_answer, range(num_requests))
            for _ in tqdm(results, total=num_requests):
                pass
    latency = time.time() - tic

    return states, latency
def main(args):
    """Run the benchmark selected by ``args.mode`` and record its latency.

    Prints the latency, passes the collected states to ``dump_state_text``
    under ``tmp_output_<backend>_<mode>.txt``, and appends one JSON line with
    the run summary to ``args.result_file``.

    Args:
        args: Parsed command-line namespace. Reads ``mode``, ``backend``,
            ``num_jsons``, ``parallel`` and ``result_file``. ``data_path`` is
            honoured when set; otherwise it is filled with the per-mode default
            ("dataset.txt" for character, "questions.jsonl" for city).

    Raises:
        ValueError: If ``args.mode`` is neither "character" nor "city".
    """
    # Only fall back to the built-in dataset name when no path was supplied;
    # previously a user-provided --data-path was silently overwritten.
    data_path = getattr(args, "data_path", None)

    if args.mode == "character":
        args.data_path = "dataset.txt" if data_path is None else data_path
        states, latency = bench_character(args)
    elif args.mode == "city":
        args.data_path = "questions.jsonl" if data_path is None else data_path
        states, latency = bench_city_doc(args)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(f"Invalid mode: {args.mode}")

    # Report latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_jump_forward",
            "backend": args.backend,
            "latency": round(latency, 3),
            "num_jsons": args.num_jsons,
            "mode": args.mode,
            "parallel": args.parallel,
        }
        fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
    # Script entry point: declare the benchmark-specific options, let the
    # shared helper add its own options and parse, then run the benchmark.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str)
    parser.add_argument("--num-jsons", type=int, default=50)
    parser.add_argument(
        "--mode",
        type=str,
        default="character",
        choices=["character", "city"],
    )
    args = add_common_other_args_and_parse(parser)
    main(args)