Add support for OpenAI API parallel sampling (#640)
This commit is contained in:
75
examples/usage/openai_parallel_sample.py
Normal file
75
examples/usage/openai_parallel_sample.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# Example: parallel sampling (the `n` parameter) through the OpenAI-compatible
# API. Each request below is sent to a server listening on 127.0.0.1:30000 and
# the raw response object is printed.
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Text completion: one prompt, one sample (n=1)
response = client.completions.create(
    model="default",
    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
    n=1,
    temperature=0.8,
    max_tokens=32,
)
print(response)


# Text completion: one prompt, three parallel samples (n=3)
response = client.completions.create(
    model="default",
    prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
    n=3,
    temperature=0.8,
    max_tokens=32,
)
print(response)


# Text completion: a batch of two prompts, one sample each (n=1)
response = client.completions.create(
    model="default",
    prompt=["The name of the famous soccer player is ", "The capital of US is"],
    n=1,
    temperature=0.8,
    max_tokens=32,
)
print(response)


# Text completion: a batch of two prompts, three parallel samples each (n=3)
response = client.completions.create(
    model="default",
    prompt=["The name of the famous soccer player is ", "The capital of US is"],
    n=3,
    temperature=0.8,
    max_tokens=32,
)
print(response)


# Text completion: a batch of three prompts, three parallel samples each (n=3)
response = client.completions.create(
    model="default",
    prompt=[
        "The capital of France is",
        "The capital of Germany is",
        "The capital of US is",
    ],
    n=3,
    temperature=0.8,
    max_tokens=32,
)
print(response)

# Chat completion: one conversation, four parallel samples (n=4), with logprobs
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=64,
    logprobs=True,
    n=4,
)
print(response)
|
||||||
@@ -40,11 +40,13 @@ class GenerateReqInput:
|
|||||||
self.text is not None and self.input_ids is not None
|
self.text is not None and self.input_ids is not None
|
||||||
):
|
):
|
||||||
raise ValueError("Either text or input_ids should be provided.")
|
raise ValueError("Either text or input_ids should be provided.")
|
||||||
|
if "n" in self.sampling_params and self.sampling_params["n"] != 1:
|
||||||
if self.text is not None:
|
is_single = False
|
||||||
is_single = isinstance(self.text, str)
|
|
||||||
else:
|
else:
|
||||||
is_single = isinstance(self.input_ids[0], int)
|
if self.text is not None:
|
||||||
|
is_single = isinstance(self.text, str)
|
||||||
|
else:
|
||||||
|
is_single = isinstance(self.input_ids[0], int)
|
||||||
self.is_single = is_single
|
self.is_single = is_single
|
||||||
|
|
||||||
if is_single:
|
if is_single:
|
||||||
@@ -59,7 +61,22 @@ class GenerateReqInput:
|
|||||||
if self.top_logprobs_num is None:
|
if self.top_logprobs_num is None:
|
||||||
self.top_logprobs_num = 0
|
self.top_logprobs_num = 0
|
||||||
else:
|
else:
|
||||||
num = len(self.text) if self.text is not None else len(self.input_ids)
|
|
||||||
|
parallel_sample_num = self.sampling_params.get("n", 1)
|
||||||
|
|
||||||
|
if parallel_sample_num != 1:
|
||||||
|
# parallel sampling +1 represents the original prefill stage
|
||||||
|
num = parallel_sample_num + 1
|
||||||
|
if isinstance(self.text, List):
|
||||||
|
## support batch operation
|
||||||
|
self.batch_size = len(self.text)
|
||||||
|
num = num * len(self.text)
|
||||||
|
else:
|
||||||
|
self.batch_size = 1
|
||||||
|
else:
|
||||||
|
## support select operation
|
||||||
|
num = len(self.text) if self.text is not None else len(self.input_ids)
|
||||||
|
self.batch_size = num
|
||||||
|
|
||||||
if self.image_data is None:
|
if self.image_data is None:
|
||||||
self.image_data = [None] * num
|
self.image_data = [None] * num
|
||||||
|
|||||||
@@ -122,125 +122,150 @@ class TokenizerManager:
|
|||||||
|
|
||||||
obj.post_init()
|
obj.post_init()
|
||||||
is_single = obj.is_single
|
is_single = obj.is_single
|
||||||
|
|
||||||
if is_single:
|
if is_single:
|
||||||
rid = obj.rid
|
async for response in self._handle_single_request(obj, request):
|
||||||
|
yield response
|
||||||
if obj.input_ids is None:
|
|
||||||
input_ids = self.tokenizer.encode(obj.text)
|
|
||||||
else:
|
|
||||||
input_ids = obj.input_ids
|
|
||||||
|
|
||||||
if len(input_ids) >= self.context_len:
|
|
||||||
raise ValueError(
|
|
||||||
f"The input ({len(input_ids)} tokens) is longer than the "
|
|
||||||
f"model's context length ({self.context_len} tokens)."
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params)
|
|
||||||
if sampling_params.max_new_tokens != 0:
|
|
||||||
sampling_params.normalize(self.tokenizer)
|
|
||||||
sampling_params.verify()
|
|
||||||
|
|
||||||
if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
|
|
||||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
|
||||||
obj.image_data[0]
|
|
||||||
)
|
|
||||||
elif isinstance(obj.image_data, str):
|
|
||||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
|
||||||
obj.image_data
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pixel_values, image_hash, image_size = None, None, None
|
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
|
||||||
rid=rid,
|
|
||||||
input_text=obj.text,
|
|
||||||
input_ids=input_ids,
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
image_hash=image_hash,
|
|
||||||
image_size=image_size,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
return_logprob=obj.return_logprob,
|
|
||||||
logprob_start_len=obj.logprob_start_len,
|
|
||||||
top_logprobs_num=obj.top_logprobs_num,
|
|
||||||
stream=obj.stream,
|
|
||||||
)
|
|
||||||
self.send_to_router.send_pyobj(tokenized_obj)
|
|
||||||
|
|
||||||
event = asyncio.Event()
|
|
||||||
state = ReqState([], False, event)
|
|
||||||
self.rid_to_state[rid] = state
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(event.wait(), timeout=4)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
if request is not None and await request.is_disconnected():
|
|
||||||
self.abort_request(rid)
|
|
||||||
raise ValueError(f"Abort request {rid}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
out = self.convert_logprob_style(
|
|
||||||
state.out_list[-1],
|
|
||||||
obj.return_logprob,
|
|
||||||
obj.top_logprobs_num,
|
|
||||||
obj.return_text_in_logprobs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.server_args.log_requests and state.finished:
|
|
||||||
logger.info(f"in={obj.text}, out={out}")
|
|
||||||
|
|
||||||
state.out_list = []
|
|
||||||
if state.finished:
|
|
||||||
del self.rid_to_state[rid]
|
|
||||||
|
|
||||||
yield out
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
event.clear()
|
|
||||||
|
|
||||||
yield out
|
|
||||||
else:
|
else:
|
||||||
if obj.stream:
|
if obj.stream:
|
||||||
raise ValueError("Do not support stream for batch mode.")
|
raise ValueError("Do not support stream for batch mode.")
|
||||||
|
|
||||||
if obj.input_ids is None:
|
async for response in self._handle_batch_request(obj, request):
|
||||||
bs = len(obj.text)
|
yield response
|
||||||
|
|
||||||
|
async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
|
||||||
|
if is_prefill:
|
||||||
|
if isinstance(obj.text, list):
|
||||||
|
input_text = obj.text[index]
|
||||||
|
rid = obj.rid[index]
|
||||||
else:
|
else:
|
||||||
bs = len(obj.input_ids)
|
input_text = obj.text
|
||||||
|
rid = obj.rid[0]
|
||||||
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
|
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||||
|
sampling_params.max_new_tokens = 0
|
||||||
|
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||||
|
obj.image_data[0]
|
||||||
|
)
|
||||||
|
return_logprob = obj.return_logprob[0]
|
||||||
|
logprob_start_len = obj.logprob_start_len[0]
|
||||||
|
top_logprobs_num = obj.top_logprobs_num[0]
|
||||||
|
else:
|
||||||
|
rid = obj.rid if index is None else obj.rid[index]
|
||||||
|
input_text = obj.text if index is None else obj.text[index]
|
||||||
|
input_ids = (
|
||||||
|
self.tokenizer.encode(input_text)
|
||||||
|
if obj.input_ids is None
|
||||||
|
else obj.input_ids
|
||||||
|
)
|
||||||
|
if index is not None and obj.input_ids:
|
||||||
|
input_ids = obj.input_ids[index]
|
||||||
|
|
||||||
for i in range(bs):
|
self._validate_input_length(input_ids)
|
||||||
rid = obj.rid[i]
|
sampling_params = self._get_sampling_params(
|
||||||
|
obj.sampling_params if index is None else obj.sampling_params[index]
|
||||||
|
)
|
||||||
|
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||||
|
obj.image_data if index is None else obj.image_data[index]
|
||||||
|
)
|
||||||
|
return_logprob = (
|
||||||
|
obj.return_logprob if index is None else obj.return_logprob[index]
|
||||||
|
)
|
||||||
|
logprob_start_len = (
|
||||||
|
obj.logprob_start_len if index is None else obj.logprob_start_len[index]
|
||||||
|
)
|
||||||
|
top_logprobs_num = (
|
||||||
|
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
|
||||||
|
)
|
||||||
|
|
||||||
if obj.input_ids is None:
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
input_text = obj.text[i]
|
rid,
|
||||||
input_ids = self.tokenizer.encode(obj.text[i])
|
input_text,
|
||||||
|
input_ids,
|
||||||
|
pixel_values,
|
||||||
|
image_hash,
|
||||||
|
image_size,
|
||||||
|
sampling_params,
|
||||||
|
return_logprob,
|
||||||
|
logprob_start_len,
|
||||||
|
top_logprobs_num,
|
||||||
|
obj.stream,
|
||||||
|
)
|
||||||
|
self.send_to_router.send_pyobj(tokenized_obj)
|
||||||
|
|
||||||
|
event = asyncio.Event()
|
||||||
|
state = ReqState([], False, event)
|
||||||
|
self.rid_to_state[rid] = state
|
||||||
|
if is_prefill == False:
|
||||||
|
async for response in self._wait_for_response(
|
||||||
|
event, state, obj, rid, request
|
||||||
|
):
|
||||||
|
yield response
|
||||||
|
else:
|
||||||
|
await self._wait_for_prefill_response(event, state, obj, request, rid)
|
||||||
|
yield input_ids
|
||||||
|
|
||||||
|
async def _handle_batch_request(self, obj, request):
|
||||||
|
batch_size = obj.batch_size
|
||||||
|
parallel_sample_num = obj.sampling_params[0].get("n", 1)
|
||||||
|
|
||||||
|
if parallel_sample_num != 1:
|
||||||
|
## send prefill requests
|
||||||
|
parallel_sample_num += 1
|
||||||
|
input_id_result = [] if obj.input_ids is None else None
|
||||||
|
for i in range(batch_size):
|
||||||
|
async for input_id in self._handle_single_request(
|
||||||
|
obj, request, index=i, is_prefill=True
|
||||||
|
):
|
||||||
|
if input_id_result is not None:
|
||||||
|
input_id_result.append(input_id)
|
||||||
|
pass
|
||||||
|
if len(input_id_result) > 1 and input_id_result is not None:
|
||||||
|
obj.input_ids = input_id_result
|
||||||
|
elif input_id_result is not None:
|
||||||
|
obj.input_ids = input_id_result[0]
|
||||||
|
# First send out all requests
|
||||||
|
for i in range(batch_size):
|
||||||
|
for j in range(parallel_sample_num):
|
||||||
|
if j == 0 and parallel_sample_num != 1:
|
||||||
|
continue
|
||||||
|
index = i * parallel_sample_num + j
|
||||||
|
if parallel_sample_num != 1:
|
||||||
|
# When using parallel sampling we should account for the prefill stage, so the index is: j + i * (parallel_sample_num - 1) + batch_size - 1
|
||||||
|
index += batch_size - 1 - i
|
||||||
|
rid = obj.rid[index]
|
||||||
|
if parallel_sample_num == 1:
|
||||||
|
## select operation
|
||||||
|
if obj.input_ids is None:
|
||||||
|
input_text = obj.text[i]
|
||||||
|
input_ids = self.tokenizer.encode(obj.text[i])
|
||||||
|
else:
|
||||||
|
input_text = None
|
||||||
|
input_ids = obj.input_ids[i]
|
||||||
else:
|
else:
|
||||||
input_text = None
|
if batch_size == 1:
|
||||||
input_ids = obj.input_ids[i]
|
input_text = obj.text
|
||||||
|
input_ids = obj.input_ids
|
||||||
|
else:
|
||||||
|
input_text = obj.text[i]
|
||||||
|
input_ids = obj.input_ids[i]
|
||||||
|
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||||
|
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||||
|
obj.image_data[index]
|
||||||
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params[i])
|
|
||||||
if sampling_params.max_new_tokens != 0:
|
|
||||||
sampling_params.normalize(self.tokenizer)
|
|
||||||
sampling_params.verify()
|
|
||||||
if obj.image_data[i] is None:
|
|
||||||
pixel_values, image_hash, image_size = None, None, None
|
|
||||||
else:
|
|
||||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
|
||||||
obj.image_data[i]
|
|
||||||
)
|
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
rid=rid,
|
rid,
|
||||||
input_text=input_text,
|
input_text,
|
||||||
input_ids=input_ids,
|
input_ids,
|
||||||
pixel_values=pixel_values,
|
pixel_values,
|
||||||
image_hash=image_hash,
|
image_hash,
|
||||||
image_size=image_size,
|
image_size,
|
||||||
sampling_params=sampling_params,
|
sampling_params,
|
||||||
return_logprob=obj.return_logprob[i],
|
obj.return_logprob[index],
|
||||||
logprob_start_len=obj.logprob_start_len[i],
|
obj.logprob_start_len[index],
|
||||||
top_logprobs_num=obj.top_logprobs_num[i],
|
obj.top_logprobs_num[index],
|
||||||
stream=obj.stream,
|
obj.stream,
|
||||||
)
|
)
|
||||||
self.send_to_router.send_pyobj(tokenized_obj)
|
self.send_to_router.send_pyobj(tokenized_obj)
|
||||||
|
|
||||||
@@ -248,9 +273,16 @@ class TokenizerManager:
|
|||||||
state = ReqState([], False, event)
|
state = ReqState([], False, event)
|
||||||
self.rid_to_state[rid] = state
|
self.rid_to_state[rid] = state
|
||||||
|
|
||||||
output_list = []
|
# Then wait for all responses
|
||||||
for i in range(bs):
|
output_list = []
|
||||||
rid = obj.rid[i]
|
for i in range(batch_size):
|
||||||
|
for j in range(parallel_sample_num):
|
||||||
|
if j == 0 and parallel_sample_num != 1:
|
||||||
|
continue
|
||||||
|
index = i * parallel_sample_num + j
|
||||||
|
if parallel_sample_num != 1:
|
||||||
|
index += batch_size - 1 - i
|
||||||
|
rid = obj.rid[index]
|
||||||
state = self.rid_to_state[rid]
|
state = self.rid_to_state[rid]
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -263,19 +295,86 @@ class TokenizerManager:
|
|||||||
self.abort_request(rid)
|
self.abort_request(rid)
|
||||||
raise ValueError(f"Abort request {rid}")
|
raise ValueError(f"Abort request {rid}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
output_list.append(
|
output_list.append(
|
||||||
self.convert_logprob_style(
|
self.convert_logprob_style(
|
||||||
state.out_list[-1],
|
state.out_list[-1],
|
||||||
obj.return_logprob[i],
|
obj.return_logprob[index],
|
||||||
obj.top_logprobs_num[i],
|
obj.top_logprobs_num[index],
|
||||||
obj.return_text_in_logprobs,
|
obj.return_text_in_logprobs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert state.finished
|
assert state.finished
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
|
|
||||||
yield output_list
|
yield output_list
|
||||||
|
|
||||||
|
def _validate_input_length(self, input_ids):
    """Reject a prompt that does not fit in the model's context window.

    Args:
        input_ids: Token ids of a single prompt; only its length is inspected.

    Raises:
        ValueError: If ``len(input_ids)`` is greater than or equal to
            ``self.context_len``.
    """
    num_tokens = len(input_ids)
    if num_tokens >= self.context_len:
        raise ValueError(
            f"The input ({num_tokens} tokens) is longer than the "
            f"model's context length ({self.context_len} tokens)."
        )
|
||||||
|
|
||||||
|
def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
    """Build a ``SamplingParams`` object from a dict of keyword arguments.

    Args:
        sampling_params_data: Mapping expanded as keyword arguments into
            ``SamplingParams``.
        max_new_tokens: When not ``None``, overrides the ``max_new_tokens``
            attribute of the freshly built object.

    Returns:
        The constructed ``SamplingParams`` instance.
    """
    sampling_params = SamplingParams(**sampling_params_data)
    if max_new_tokens is not None:
        sampling_params.max_new_tokens = max_new_tokens
    # Requests that generate no new tokens skip normalization and verification.
    # NOTE(review): verify() is kept inside this guard to mirror the inline code
    # this helper replaces; confirm the nesting against the real file.
    if sampling_params.max_new_tokens != 0:
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()
    return sampling_params
|
||||||
|
|
||||||
|
async def _get_pixel_values(self, image_data):
    """Resolve the image input of one request via ``self.get_pixel_values``.

    Args:
        image_data: A non-empty list (only its first element is used), a
            string, or anything else (treated as "no image").

    Returns:
        Whatever ``self.get_pixel_values`` returns for the selected item, or
        ``(None, None, None)`` when there is no usable image input (including
        an empty list).
    """
    if isinstance(image_data, list) and image_data:
        # Only the first entry of a list is forwarded.
        return await self.get_pixel_values(image_data[0])
    if isinstance(image_data, str):
        return await self.get_pixel_values(image_data)
    return None, None, None
|
||||||
|
|
||||||
|
async def _wait_for_response(self, event, state, obj, rid, request):
    """Yield converted outputs for one request until its state is finished.

    Each time ``event`` is set, the last entry of ``state.out_list`` is passed
    through ``self.convert_logprob_style`` and yielded. When ``state.finished``
    is true the entry for ``rid`` is removed from ``self.rid_to_state`` and the
    generator ends after yielding that final output.

    Args:
        event: ``asyncio.Event`` awaited for new output.
        state: Object exposing ``out_list`` and ``finished``.
        obj: Request object; ``return_logprob``, ``top_logprobs_num``,
            ``return_text_in_logprobs`` and ``text`` are read from it.
        rid: Request id used for abort and for the ``rid_to_state`` cleanup.
        request: Optional object with an awaitable ``is_disconnected()``;
            may be ``None``.

    Raises:
        ValueError: If the wait times out and ``request`` reports that the
            client has disconnected (the request is aborted first).
    """
    while True:
        try:
            # Wake up every 4 seconds to check whether the client is still there.
            await asyncio.wait_for(event.wait(), timeout=4)
        except asyncio.TimeoutError:
            if request is not None and await request.is_disconnected():
                self.abort_request(rid)
                raise ValueError(f"Abort request {rid}")
            continue

        out = self.convert_logprob_style(
            state.out_list[-1],
            obj.return_logprob,
            obj.top_logprobs_num,
            obj.return_text_in_logprobs,
        )

        if self.server_args.log_requests and state.finished:
            logger.info(f"in={obj.text}, out={out}")

        # Only the latest output is consumed; older entries are discarded.
        state.out_list = []
        if state.finished:
            # Drop the bookkeeping entry before handing out the final result.
            del self.rid_to_state[rid]
            yield out
            break

        # Re-arm the event before yielding so the next set() is not missed.
        event.clear()
        yield out
|
||||||
|
|
||||||
|
async def _wait_for_prefill_response(self, event, state, obj, request, rid):
    """Block until the prefill request ``rid`` has completed, then forget it.

    Args:
        event: Accepted for signature compatibility; the wait is performed on
            ``state.event``.
        state: Object exposing ``event`` (awaited) and ``finished``.
        obj: Request object; ``obj.rid`` is iterated to abort every request
            of the batch if the client disconnects.
        request: Optional object with an awaitable ``is_disconnected()``;
            may be ``None``.
        rid: Id of the prefill request being awaited; its entry is removed
            from ``self.rid_to_state`` on completion.

    Raises:
        ValueError: If the wait times out and ``request`` reports that the
            client has disconnected. The message names ``rid``.
    """
    while True:
        try:
            # Wake up every 4 seconds to check whether the client is still there.
            await asyncio.wait_for(state.event.wait(), timeout=4)
            break
        except asyncio.TimeoutError:
            if request is not None and await request.is_disconnected():
                # Use a separate loop variable: rebinding ``rid`` here would
                # make the error below report the last id of the batch
                # instead of the request actually being awaited.
                for pending_rid in obj.rid:
                    self.abort_request(pending_rid)
                raise ValueError(f"Abort request {rid}")
            continue

    assert state.finished
    del self.rid_to_state[rid]
|
||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
req = FlushCacheReq()
|
req = FlushCacheReq()
|
||||||
|
|||||||
@@ -95,9 +95,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
request_json = await raw_request.json()
|
request_json = await raw_request.json()
|
||||||
request = CompletionRequest(**request_json)
|
request = CompletionRequest(**request_json)
|
||||||
|
|
||||||
if request.n != 1:
|
|
||||||
return create_error_response("n != 1 is not supported")
|
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
text=request.prompt,
|
text=request.prompt,
|
||||||
sampling_params={
|
sampling_params={
|
||||||
@@ -108,6 +105,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
"presence_penalty": request.presence_penalty,
|
"presence_penalty": request.presence_penalty,
|
||||||
"frequency_penalty": request.frequency_penalty,
|
"frequency_penalty": request.frequency_penalty,
|
||||||
"regex": request.regex,
|
"regex": request.regex,
|
||||||
|
"n": request.n,
|
||||||
},
|
},
|
||||||
return_logprob=request.logprobs is not None and request.logprobs > 0,
|
return_logprob=request.logprobs is not None and request.logprobs > 0,
|
||||||
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
|
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
|
||||||
@@ -202,46 +200,56 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(str(e))
|
return create_error_response(str(e))
|
||||||
|
|
||||||
ret = ret[0] if isinstance(ret, list) else ret
|
if not isinstance(ret, list):
|
||||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
ret = [ret]
|
||||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
choices = []
|
||||||
text = ret["text"]
|
|
||||||
if request.echo:
|
for idx, ret_item in enumerate(ret):
|
||||||
text = request.prompt + text
|
text = ret_item["text"]
|
||||||
|
|
||||||
if request.logprobs:
|
|
||||||
if request.echo:
|
if request.echo:
|
||||||
prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"]
|
text = request.prompt + text
|
||||||
prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"]
|
|
||||||
|
if request.logprobs:
|
||||||
|
if request.echo:
|
||||||
|
prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
|
||||||
|
prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
|
||||||
|
else:
|
||||||
|
prefill_token_logprobs = None
|
||||||
|
prefill_top_logprobs = None
|
||||||
|
|
||||||
|
logprobs = to_openai_style_logprobs(
|
||||||
|
prefill_token_logprobs=prefill_token_logprobs,
|
||||||
|
prefill_top_logprobs=prefill_top_logprobs,
|
||||||
|
decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
|
||||||
|
decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prefill_token_logprobs = None
|
logprobs = None
|
||||||
prefill_top_logprobs = None
|
|
||||||
|
|
||||||
logprobs = to_openai_style_logprobs(
|
choice_data = CompletionResponseChoice(
|
||||||
prefill_token_logprobs=prefill_token_logprobs,
|
index=idx,
|
||||||
prefill_top_logprobs=prefill_top_logprobs,
|
text=text,
|
||||||
decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"],
|
logprobs=logprobs,
|
||||||
decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"],
|
finish_reason=ret_item["meta_info"]["finish_reason"],
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
logprobs = None
|
|
||||||
|
|
||||||
choice_data = CompletionResponseChoice(
|
choices.append(choice_data)
|
||||||
index=0,
|
|
||||||
text=text,
|
|
||||||
logprobs=logprobs,
|
|
||||||
finish_reason=ret["meta_info"]["finish_reason"],
|
|
||||||
)
|
|
||||||
response = CompletionResponse(
|
response = CompletionResponse(
|
||||||
id=ret["meta_info"]["id"],
|
id=ret[0]["meta_info"]["id"],
|
||||||
model=request.model,
|
model=request.model,
|
||||||
choices=[choice_data],
|
choices=choices,
|
||||||
usage=UsageInfo(
|
usage=UsageInfo(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=sum(
|
||||||
total_tokens=prompt_tokens + completion_tokens,
|
item["meta_info"]["completion_tokens"] for item in ret
|
||||||
|
),
|
||||||
|
total_tokens=ret[0]["meta_info"]["prompt_tokens"]
|
||||||
|
+ sum(item["meta_info"]["completion_tokens"] for item in ret),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@@ -249,9 +257,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
request_json = await raw_request.json()
|
request_json = await raw_request.json()
|
||||||
request = ChatCompletionRequest(**request_json)
|
request = ChatCompletionRequest(**request_json)
|
||||||
|
|
||||||
if request.n != 1:
|
|
||||||
return create_error_response("n != 1 is not supported")
|
|
||||||
|
|
||||||
# Prep the data needed for the underlying GenerateReqInput:
|
# Prep the data needed for the underlying GenerateReqInput:
|
||||||
# - prompt: The full prompt string.
|
# - prompt: The full prompt string.
|
||||||
# - stop: Custom stop tokens.
|
# - stop: Custom stop tokens.
|
||||||
@@ -292,6 +297,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
"presence_penalty": request.presence_penalty,
|
"presence_penalty": request.presence_penalty,
|
||||||
"frequency_penalty": request.frequency_penalty,
|
"frequency_penalty": request.frequency_penalty,
|
||||||
"regex": request.regex,
|
"regex": request.regex,
|
||||||
|
"n": request.n,
|
||||||
},
|
},
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
)
|
)
|
||||||
@@ -354,23 +360,37 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(str(e))
|
return create_error_response(str(e))
|
||||||
|
|
||||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
if not isinstance(ret, list):
|
||||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
ret = [ret]
|
||||||
choice_data = ChatCompletionResponseChoice(
|
choices = []
|
||||||
index=0,
|
total_prompt_tokens = 0
|
||||||
message=ChatMessage(role="assistant", content=ret["text"]),
|
total_completion_tokens = 0
|
||||||
finish_reason=ret["meta_info"]["finish_reason"],
|
|
||||||
)
|
for idx, ret_item in enumerate(ret):
|
||||||
|
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
|
||||||
|
completion_tokens = ret_item["meta_info"]["completion_tokens"]
|
||||||
|
|
||||||
|
choice_data = ChatCompletionResponseChoice(
|
||||||
|
index=idx,
|
||||||
|
message=ChatMessage(role="assistant", content=ret_item["text"]),
|
||||||
|
finish_reason=ret_item["meta_info"]["finish_reason"],
|
||||||
|
)
|
||||||
|
|
||||||
|
choices.append(choice_data)
|
||||||
|
total_prompt_tokens = prompt_tokens
|
||||||
|
total_completion_tokens += completion_tokens
|
||||||
|
|
||||||
response = ChatCompletionResponse(
|
response = ChatCompletionResponse(
|
||||||
id=ret["meta_info"]["id"],
|
id=ret[0]["meta_info"]["id"],
|
||||||
model=request.model,
|
model=request.model,
|
||||||
choices=[choice_data],
|
choices=choices,
|
||||||
usage=UsageInfo(
|
usage=UsageInfo(
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=total_prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=total_completion_tokens,
|
||||||
total_tokens=prompt_tokens + completion_tokens,
|
total_tokens=total_prompt_tokens + total_completion_tokens,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class SamplingParams:
|
|||||||
spaces_between_special_tokens: bool = True,
|
spaces_between_special_tokens: bool = True,
|
||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
|
n: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
@@ -33,6 +34,7 @@ class SamplingParams:
|
|||||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.regex = regex
|
self.regex = regex
|
||||||
|
self.n = n
|
||||||
|
|
||||||
# Process some special cases
|
# Process some special cases
|
||||||
if self.temperature < _SAMPLING_EPS:
|
if self.temperature < _SAMPLING_EPS:
|
||||||
|
|||||||
Reference in New Issue
Block a user