Add skip_tokenizer_init args. (#959)
Co-authored-by: lzhang <zhanglei@modelbest.cn>
This commit is contained in:
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
|||||||
|
|
||||||
|
|
||||||
class FSMCache(BaseToolCache):
|
class FSMCache(BaseToolCache):
|
||||||
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer_path,
|
||||||
|
tokenizer_args_dict,
|
||||||
|
enable=True,
|
||||||
|
skip_tokenizer_init=False,
|
||||||
|
):
|
||||||
super().__init__(enable=enable)
|
super().__init__(enable=enable)
|
||||||
|
|
||||||
if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
|
if (
|
||||||
|
skip_tokenizer_init
|
||||||
|
or tokenizer_path.endswith(".json")
|
||||||
|
or tokenizer_path.endswith(".model")
|
||||||
|
):
|
||||||
# Do not support TiktokenTokenizer or SentencePieceTokenizer
|
# Do not support TiktokenTokenizer or SentencePieceTokenizer
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -59,11 +59,14 @@ class DetokenizerManager:
|
|||||||
self.send_to_tokenizer = context.socket(zmq.PUSH)
|
self.send_to_tokenizer = context.socket(zmq.PUSH)
|
||||||
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
||||||
|
|
||||||
self.tokenizer = get_tokenizer(
|
if server_args.skip_tokenizer_init:
|
||||||
server_args.tokenizer_path,
|
self.tokenizer = None
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
else:
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
self.tokenizer = get_tokenizer(
|
||||||
)
|
server_args.tokenizer_path,
|
||||||
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
self.decode_status = {}
|
self.decode_status = {}
|
||||||
|
|
||||||
@@ -85,6 +88,11 @@ class DetokenizerManager:
|
|||||||
assert isinstance(recv_obj, BatchTokenIDOut)
|
assert isinstance(recv_obj, BatchTokenIDOut)
|
||||||
bs = len(recv_obj.rids)
|
bs = len(recv_obj.rids)
|
||||||
|
|
||||||
|
if self.tokenizer is None:
|
||||||
|
# Send BatchTokenIDOut if no tokenizer init'ed.
|
||||||
|
self.send_to_tokenizer.send_pyobj(recv_obj)
|
||||||
|
continue
|
||||||
|
|
||||||
# Initialize decode status
|
# Initialize decode status
|
||||||
read_ids, surr_ids = [], []
|
read_ids, surr_ids = [], []
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
|
|||||||
@@ -195,6 +195,8 @@ class Req:
|
|||||||
return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
|
return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
|
||||||
|
|
||||||
def get_next_inc_detokenization(self):
|
def get_next_inc_detokenization(self):
|
||||||
|
if self.tokenizer is None:
|
||||||
|
return False, ""
|
||||||
read_ids, read_offset = self.init_incremental_detokenize()
|
read_ids, read_offset = self.init_incremental_detokenize()
|
||||||
surr_ids = read_ids[:read_offset]
|
surr_ids = read_ids[:read_offset]
|
||||||
|
|
||||||
@@ -225,16 +227,11 @@ class Req:
|
|||||||
return
|
return
|
||||||
|
|
||||||
last_token_id = self.output_ids[-1]
|
last_token_id = self.output_ids[-1]
|
||||||
if (
|
if self.tokenizer is None:
|
||||||
last_token_id == self.tokenizer.eos_token_id
|
matched_eos = last_token_id in self.sampling_params.stop_token_ids
|
||||||
and not self.sampling_params.ignore_eos
|
else:
|
||||||
):
|
matched_eos = last_token_id == self.tokenizer.eos_token_id
|
||||||
self.finished_reason = FINISH_MATCHED_TOKEN(
|
if matched_eos and not self.sampling_params.ignore_eos:
|
||||||
matched=self.tokenizer.eos_token_id
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if last_token_id in self.sampling_params.stop_token_ids:
|
|
||||||
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -95,25 +95,28 @@ class TokenizerManager:
|
|||||||
else:
|
else:
|
||||||
self.context_len = get_context_length(self.hf_config)
|
self.context_len = get_context_length(self.hf_config)
|
||||||
|
|
||||||
if is_multimodal_model(self.model_path):
|
if server_args.skip_tokenizer_init:
|
||||||
self.processor = get_processor(
|
self.tokenizer = self.processor = None
|
||||||
server_args.tokenizer_path,
|
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
)
|
|
||||||
self.tokenizer = self.processor.tokenizer
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
|
||||||
initializer=init_global_processor,
|
|
||||||
mp_context=mp.get_context("fork"),
|
|
||||||
initargs=(server_args,),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.tokenizer = get_tokenizer(
|
if is_multimodal_model(self.model_path):
|
||||||
server_args.tokenizer_path,
|
self.processor = get_processor(
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
server_args.tokenizer_path,
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
)
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
|
)
|
||||||
|
self.tokenizer = self.processor.tokenizer
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||||
|
initializer=init_global_processor,
|
||||||
|
mp_context=mp.get_context("fork"),
|
||||||
|
initargs=(server_args,),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.tokenizer = get_tokenizer(
|
||||||
|
server_args.tokenizer_path,
|
||||||
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
self.to_create_loop = True
|
self.to_create_loop = True
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
@@ -171,6 +174,7 @@ class TokenizerManager:
|
|||||||
rid = obj.rid if not_use_index else obj.rid[index]
|
rid = obj.rid if not_use_index else obj.rid[index]
|
||||||
input_text = obj.text if not_use_index else obj.text[index]
|
input_text = obj.text if not_use_index else obj.text[index]
|
||||||
if obj.input_ids is None:
|
if obj.input_ids is None:
|
||||||
|
assert self.tokenizer is not None
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
else:
|
else:
|
||||||
input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
|
input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
|
||||||
@@ -207,7 +211,20 @@ class TokenizerManager:
|
|||||||
else:
|
else:
|
||||||
input_text = obj.text
|
input_text = obj.text
|
||||||
rid = obj.rid[0]
|
rid = obj.rid[0]
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
if self.tokenizer is not None:
|
||||||
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
|
else:
|
||||||
|
assert obj.input_ids is not None
|
||||||
|
input_ids = obj.input_ids
|
||||||
|
if isinstance(obj.input_ids, list) and isinstance(
|
||||||
|
obj.input_ids[0], list
|
||||||
|
):
|
||||||
|
# when obj["input_ids"] is List[List[int]]
|
||||||
|
input_ids = obj.input_ids[index]
|
||||||
|
rid = obj.rid[index]
|
||||||
|
else:
|
||||||
|
input_ids = obj.input_ids
|
||||||
|
rid = obj.rid[0]
|
||||||
else:
|
else:
|
||||||
input_text = None
|
input_text = None
|
||||||
if isinstance(obj.input_ids, list) and isinstance(
|
if isinstance(obj.input_ids, list) and isinstance(
|
||||||
@@ -420,7 +437,7 @@ class TokenizerManager:
|
|||||||
# Log requests
|
# Log requests
|
||||||
if self.server_args.log_requests and state.finished:
|
if self.server_args.log_requests and state.finished:
|
||||||
if obj.text is None:
|
if obj.text is None:
|
||||||
in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
|
in_obj = {"input_ids": obj.input_ids}
|
||||||
else:
|
else:
|
||||||
in_obj = {"text": obj.text}
|
in_obj = {"text": obj.text}
|
||||||
logger.info(f"in={in_obj}, out={out}")
|
logger.info(f"in={in_obj}, out={out}")
|
||||||
@@ -488,11 +505,12 @@ class TokenizerManager:
|
|||||||
|
|
||||||
async def handle_loop(self):
|
async def handle_loop(self):
|
||||||
while True:
|
while True:
|
||||||
recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
|
recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
|
||||||
await self.recv_from_detokenizer.recv_pyobj()
|
await self.recv_from_detokenizer.recv_pyobj()
|
||||||
)
|
)
|
||||||
assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
|
assert isinstance(
|
||||||
|
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
||||||
|
), f"Unexpected obj received: {type(recv_obj)}"
|
||||||
for i, rid in enumerate(recv_obj.rids):
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
state = self.rid_to_state.get(rid, None)
|
state = self.rid_to_state.get(rid, None)
|
||||||
if state is None:
|
if state is None:
|
||||||
@@ -504,6 +522,15 @@ class TokenizerManager:
|
|||||||
"text": recv_obj.output_strs[i],
|
"text": recv_obj.output_strs[i],
|
||||||
"meta_info": recv_obj.meta_info[i],
|
"meta_info": recv_obj.meta_info[i],
|
||||||
}
|
}
|
||||||
|
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||||
|
read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
|
||||||
|
out_dict = {
|
||||||
|
"token_ids": recv_obj.decode_ids[
|
||||||
|
read_start : recv_obj.read_offsets[i]
|
||||||
|
],
|
||||||
|
"meta_info": recv_obj.meta_info[i],
|
||||||
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||||
out_dict = {
|
out_dict = {
|
||||||
@@ -549,6 +576,7 @@ class TokenizerManager:
|
|||||||
if not decode_to_text:
|
if not decode_to_text:
|
||||||
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
||||||
|
|
||||||
|
assert self.tokenizer is not None
|
||||||
token_ids = [tid for _, tid in token_logprobs]
|
token_ids = [tid for _, tid in token_logprobs]
|
||||||
token_texts = self.tokenizer.batch_decode(token_ids)
|
token_texts = self.tokenizer.batch_decode(token_ids)
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -100,20 +100,22 @@ class ModelTpServer:
|
|||||||
nccl_port=nccl_port,
|
nccl_port=nccl_port,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
)
|
)
|
||||||
|
if server_args.skip_tokenizer_init:
|
||||||
if is_multimodal_model(server_args.model_path):
|
self.tokenizer = self.processor = None
|
||||||
self.processor = get_processor(
|
|
||||||
server_args.tokenizer_path,
|
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
)
|
|
||||||
self.tokenizer = self.processor.tokenizer
|
|
||||||
else:
|
else:
|
||||||
self.tokenizer = get_tokenizer(
|
if is_multimodal_model(server_args.model_path):
|
||||||
server_args.tokenizer_path,
|
self.processor = get_processor(
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
server_args.tokenizer_path,
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
)
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
|
)
|
||||||
|
self.tokenizer = self.processor.tokenizer
|
||||||
|
else:
|
||||||
|
self.tokenizer = get_tokenizer(
|
||||||
|
server_args.tokenizer_path,
|
||||||
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
|
)
|
||||||
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
|
||||||
self.max_prefill_tokens = (
|
self.max_prefill_tokens = (
|
||||||
16384
|
16384
|
||||||
@@ -182,13 +184,15 @@ class ModelTpServer:
|
|||||||
self.last_stats_tic = time.time()
|
self.last_stats_tic = time.time()
|
||||||
|
|
||||||
# Init the FSM cache for constrained generation
|
# Init the FSM cache for constrained generation
|
||||||
self.regex_fsm_cache = FSMCache(
|
if not server_args.skip_tokenizer_init:
|
||||||
server_args.tokenizer_path,
|
self.regex_fsm_cache = FSMCache(
|
||||||
{
|
server_args.tokenizer_path,
|
||||||
"tokenizer_mode": server_args.tokenizer_mode,
|
{
|
||||||
"trust_remote_code": server_args.trust_remote_code,
|
"tokenizer_mode": server_args.tokenizer_mode,
|
||||||
},
|
"trust_remote_code": server_args.trust_remote_code,
|
||||||
)
|
},
|
||||||
|
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||||
|
)
|
||||||
self.jump_forward_cache = JumpForwardCache()
|
self.jump_forward_cache = JumpForwardCache()
|
||||||
|
|
||||||
# Init new token estimation
|
# Init new token estimation
|
||||||
@@ -466,7 +470,11 @@ class ModelTpServer:
|
|||||||
|
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
else:
|
else:
|
||||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
if self.tokenizer is None:
|
||||||
|
for i, req in enumerate(batch.reqs):
|
||||||
|
next_token_ids.extend(req.sampling_params.stop_token_ids)
|
||||||
|
else:
|
||||||
|
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
pt = 0
|
pt = 0
|
||||||
|
|||||||
@@ -111,13 +111,19 @@ class SamplingParams:
|
|||||||
# Process stop strings
|
# Process stop strings
|
||||||
if self.stop_strs is None:
|
if self.stop_strs is None:
|
||||||
self.stop_strs = []
|
self.stop_strs = []
|
||||||
self.stop_str_max_len = 0
|
if self.stop_token_ids is None:
|
||||||
|
self.stop_str_max_len = 0
|
||||||
|
else:
|
||||||
|
self.stop_str_max_len = 1
|
||||||
else:
|
else:
|
||||||
if isinstance(self.stop_strs, str):
|
if isinstance(self.stop_strs, str):
|
||||||
self.stop_strs = [self.stop_strs]
|
self.stop_strs = [self.stop_strs]
|
||||||
|
|
||||||
stop_str_max_len = 0
|
stop_str_max_len = 0
|
||||||
for stop_str in self.stop_strs:
|
for stop_str in self.stop_strs:
|
||||||
stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
|
if tokenizer is not None:
|
||||||
stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
|
stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
|
||||||
|
stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
|
||||||
|
else:
|
||||||
|
stop_str_max_len = max(stop_str_max_len, len(stop_str))
|
||||||
self.stop_str_max_len = stop_str_max_len
|
self.stop_str_max_len = stop_str_max_len
|
||||||
|
|||||||
@@ -420,17 +420,22 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
|||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
||||||
max_new_tokens = 8 if model_info["is_generation"] else 1
|
max_new_tokens = 8 if model_info["is_generation"] else 1
|
||||||
|
json_data = {
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if server_args.skip_tokenizer_init:
|
||||||
|
json_data["input_ids"] = [10, 11, 12]
|
||||||
|
else:
|
||||||
|
json_data["text"] = "The capital city of France is"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _ in range(server_args.dp_size):
|
for _ in range(server_args.dp_size):
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
url + request_name,
|
url + request_name,
|
||||||
json={
|
json=json_data,
|
||||||
"text": "The capital city of France is",
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": 0,
|
|
||||||
"max_new_tokens": max_new_tokens,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=600,
|
timeout=600,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class ServerArgs:
|
|||||||
model_path: str
|
model_path: str
|
||||||
tokenizer_path: Optional[str] = None
|
tokenizer_path: Optional[str] = None
|
||||||
tokenizer_mode: str = "auto"
|
tokenizer_mode: str = "auto"
|
||||||
|
skip_tokenizer_init: bool = False
|
||||||
load_format: str = "auto"
|
load_format: str = "auto"
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
trust_remote_code: bool = True
|
trust_remote_code: bool = True
|
||||||
@@ -151,6 +152,11 @@ class ServerArgs:
|
|||||||
"tokenizer if available, and 'slow' will "
|
"tokenizer if available, and 'slow' will "
|
||||||
"always use the slow tokenizer.",
|
"always use the slow tokenizer.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-tokenizer-init",
|
||||||
|
action="store_true",
|
||||||
|
help="If set, skip init tokenizer and pass input_ids in generate request",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load-format",
|
"--load-format",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -197,6 +197,8 @@ def allocate_init_ports(
|
|||||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||||
"""Get the logit bias for integer-only tokens."""
|
"""Get the logit bias for integer-only tokens."""
|
||||||
# a bug when model's vocab size > tokenizer.vocab_size
|
# a bug when model's vocab size > tokenizer.vocab_size
|
||||||
|
if tokenizer == None:
|
||||||
|
return [-1e5] * vocab_size
|
||||||
vocab_size = tokenizer.vocab_size
|
vocab_size = tokenizer.vocab_size
|
||||||
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
||||||
for t_id in range(vocab_size):
|
for t_id in range(vocab_size):
|
||||||
|
|||||||
77
test/srt/test_skip_tokenizer_srt.py
Normal file
77
test/srt/test_skip_tokenizer_srt.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_child_process
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
|
||||||
|
|
||||||
|
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSRTEndpoint(unittest.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = "http://127.0.0.1:8157"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model, cls.base_url, timeout=300, other_args=["--skip-tokenizer-init"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_child_process(cls.process.pid)
|
||||||
|
|
||||||
|
def run_decode(
|
||||||
|
self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
|
||||||
|
):
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": [
|
||||||
|
119689,
|
||||||
|
50650,
|
||||||
|
18291,
|
||||||
|
30061,
|
||||||
|
5316,
|
||||||
|
26951,
|
||||||
|
119690,
|
||||||
|
], # The capital of France is
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0 if n == 1 else 0.5,
|
||||||
|
"max_new_tokens": 32,
|
||||||
|
"n": n,
|
||||||
|
"stop_token_ids": [119690],
|
||||||
|
},
|
||||||
|
"stream": False,
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"top_logprobs_num": top_logprobs_num,
|
||||||
|
"return_text_in_logprobs": return_text,
|
||||||
|
"logprob_start_len": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(json.dumps(response.json()))
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
def test_simple_decode(self):
|
||||||
|
self.run_decode()
|
||||||
|
|
||||||
|
def test_parallel_sample(self):
|
||||||
|
self.run_decode(n=3)
|
||||||
|
|
||||||
|
def test_logprob(self):
|
||||||
|
for top_logprobs_num in [0, 3]:
|
||||||
|
for return_text in [False, False]:
|
||||||
|
self.run_decode(
|
||||||
|
return_logprob=True,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
return_text=return_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(warnings="ignore")
|
||||||
Reference in New Issue
Block a user