feat: allow streaming for multi-prompt and/or parallel sampling (#1134)
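This change lets the OpenAI-compatible endpoints and the native /generate endpoint stream responses when a request contains multiple prompts and/or uses parallel sampling (n > 1); each streamed chunk now carries a choice index so clients can tell the interleaved outputs apart. A minimal client-side sketch of what this enables (the server address, API key, and model name below are placeholders, not values from this commit):

import openai

# Placeholders: point these at your own SGLang server and model.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

stream = client.completions.create(
    model="your-model-name",
    prompt=["The capital of France is", "The capital of Japan is"],  # multi-prompt
    n=2,  # parallel sampling: two samples per prompt
    max_tokens=32,
    stream=True,
    stream_options={"include_usage": True},
)

texts = {}
for chunk in stream:
    if chunk.usage is not None:  # final usage-only chunk when include_usage is set
        print("usage:", chunk.usage)
        continue
    choice = chunk.choices[0]
    texts[choice.index] = texts.get(choice.index, "") + (choice.text or "")

for index, text in sorted(texts.items()):
    print(index, text)

Chunks for different choices arrive interleaved; accumulating by choice.index reassembles each completion, and the final usage chunk arrives only when include_usage is requested.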
@@ -153,9 +153,6 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if hasattr(obj, "stream") and obj.stream:
-                raise ValueError("Do not support stream for batch mode.")
-
             async for response in self._handle_batch_request(obj, request):
                 yield response

@@ -311,6 +308,7 @@ class TokenizerManager:
             parallel_sample_num = 1

         # First send out all requests
+        generators = []
         for i in range(batch_size):
             for j in range(parallel_sample_num):
                 if j == 0 and parallel_sample_num != 1:
@@ -371,42 +369,48 @@ class TokenizerManager:
                 state = ReqState([], False, event)
                 self.rid_to_state[rid] = state

-        # Then wait for all responses
-        output_list = []
-        for i in range(batch_size):
-            for j in range(parallel_sample_num):
-                if j == 0 and parallel_sample_num != 1:
-                    continue
-                index = i * parallel_sample_num + j
-                if parallel_sample_num != 1:
-                    index += batch_size - 1 - i
-                rid = obj.rid[index]
-                state = self.rid_to_state[rid]
-
-                while True:
-                    try:
-                        await asyncio.wait_for(state.event.wait(), timeout=4)
-                        break
-                    except asyncio.TimeoutError:
-                        if request is not None and await request.is_disconnected():
-                            for rid in obj.rid:
-                                self.abort_request(rid)
-                            raise ValueError(f"Abort request {rid}")
-                        continue
-                if self.is_generation:
-                    output_list.append(
-                        self.convert_logprob_style(
-                            state.out_list[-1],
-                            obj.return_logprob[index],
-                            obj.top_logprobs_num[index],
-                            obj.return_text_in_logprobs,
-                        )
-                    )
-                else:
-                    output_list.append(state.out_list[-1])
-                assert state.finished
-                del self.rid_to_state[rid]
-        yield output_list
+                generators.append(
+                    self._wait_for_response(
+                        event,
+                        state,
+                        obj,
+                        rid,
+                        request,
+                        index=index,
+                        response_index=len(generators),
+                    )
+                )
+
+        # Then process the responses based on streaming option
+
+        is_stream = hasattr(obj, "stream") and obj.stream
+
+        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
+        output_list = []
+
+        while tasks:
+            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+
+            for task in done:
+                gen_index = tasks.index(task)
+
+                try:
+                    result = task.result()
+
+                    if is_stream:
+                        yield result
+                    else:
+                        output_list.append(result)
+
+                    tasks[gen_index] = asyncio.create_task(
+                        generators[gen_index].__anext__()
+                    )
+                except StopAsyncIteration:
+                    del generators[gen_index]
+                    del tasks[gen_index]
+
+        if not is_stream:
+            yield output_list

     def _validate_input_length(self, input_ids: List[int]):
         if len(input_ids) >= self.context_len:
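The core of the server-side change is the hunk above: instead of awaiting every request sequentially and yielding one combined list, _handle_batch_request now creates one _wait_for_response generator per (prompt, sample) pair and multiplexes them with asyncio.wait, yielding chunks as soon as any generator produces one. A self-contained sketch of the same fan-in pattern with toy generators (the helper names make_gen and merge are illustrative only, not part of SGLang):

import asyncio

async def make_gen(name: str, n_chunks: int, delay: float):
    # Toy stand-in for one _wait_for_response generator.
    for i in range(n_chunks):
        await asyncio.sleep(delay)
        yield {"index": name, "chunk": i}

async def merge(generators, is_stream: bool):
    # One pending __anext__ task per generator; re-arm after each item,
    # drop the generator (and its task) once it is exhausted.
    tasks = [asyncio.create_task(g.__anext__()) for g in generators]
    output_list = []
    while tasks:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            i = tasks.index(task)
            try:
                result = task.result()
                if is_stream:
                    print("stream:", result)
                else:
                    output_list.append(result)
                tasks[i] = asyncio.create_task(generators[i].__anext__())
            except StopAsyncIteration:
                del generators[i]
                del tasks[i]
    if not is_stream:
        print("batch:", output_list)

asyncio.run(merge([make_gen("a", 3, 0.05), make_gen("b", 2, 0.08)], is_stream=True))

Keeping tasks and generators index-aligned is what makes the re-arm and removal steps safe: a finished __anext__ task is replaced in place, and an exhausted generator is dropped from both lists.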
@@ -437,26 +441,35 @@ class TokenizerManager:
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         rid: str,
         request,
+        index: int = None,
+        response_index: int = 0,
     ):
         while True:
             try:
                 await asyncio.wait_for(event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
-                    self.abort_request(rid)
+                    for rid in [obj.rid] if obj.is_single else obj.rid:
+                        self.abort_request(rid)
                     raise ValueError(f"Abort request {rid}")
                 continue

             if self.is_generation:
                 out = self.convert_logprob_style(
                     state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
+                    obj.return_logprob if index is None else obj.return_logprob[index],
+                    (
+                        obj.top_logprobs_num
+                        if index is None
+                        else obj.top_logprobs_num[index]
+                    ),
                     obj.return_text_in_logprobs,
                 )
             else:  # isinstance(obj, EmbeddingReqInput)
                 out = state.out_list[-1]

+            out["index"] = response_index
+
             # Log requests
             if self.server_args.log_requests and state.finished:
                 if obj.text is None:
@@ -277,6 +277,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
+
+            # Although streaming is supported for standalone completions, it is not supported in
+            # batch mode (multiple completions in single request).
+            if body.get("stream", False):
+                raise ValueError("Streaming requests are not supported in batch mode")
+
             if end_point == "/v1/chat/completions":
                 all_requests.append(ChatCompletionRequest(**body))
             elif end_point == "/v1/completions":
@@ -592,27 +598,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:

         async def generate_stream_resp():
-            stream_buffer = ""
-            n_prev_token = 0
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    index = content["index"]
+
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
                     text = content["text"]
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]

                     if not stream_buffer:  # The first chunk
                         if request.echo:
                             if isinstance(request.prompt, str):
                                 # for the case of single str prompts
                                 prompts = request.prompt
-                            elif isinstance(request.prompt, list) and isinstance(
-                                request.prompt[0], int
-                            ):
-                                prompts = tokenizer_manager.tokenizer.decode(
-                                    request.prompt, skip_special_tokens=True
-                                )
+                            elif isinstance(request.prompt, list):
+                                if isinstance(request.prompt[0], str):
+                                    # for the case of multiple str prompts
+                                    prompts = request.prompt[index // request.n]
+                                elif isinstance(request.prompt[0], int):
+                                    # for the case of single token ids prompt
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt, skip_special_tokens=True
+                                    )
+                                elif isinstance(request.prompt[0], list) and isinstance(
+                                    request.prompt[0][0], int
+                                ):
+                                    # for the case of multiple token ids prompts
+                                    prompts = tokenizer_manager.tokenizer.decode(
+                                        request.prompt[index // request.n],
+                                        skip_special_tokens=True,
+                                    )

                             # Prepend prompt in response text.
                             text = prompts + text
@@ -649,7 +673,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = CompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         text=delta,
                         logprobs=logprobs,
                         finish_reason=format_finish_reason(
@@ -662,12 +686,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
             if request.stream_options and request.stream_options.include_usage:
+                total_prompt_tokens = sum(
+                    tokens
+                    for i, tokens in prompt_tokens.items()
+                    if i % request.n == 0
+                )
+                total_completion_tokens = sum(
+                    tokens for tokens in completion_tokens.values()
+                )
                 usage = UsageInfo(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=prompt_tokens + completion_tokens,
+                    prompt_tokens=total_prompt_tokens,
+                    completion_tokens=total_completion_tokens,
+                    total_tokens=total_prompt_tokens + total_completion_tokens,
                 )

                 final_usage_chunk = CompletionStreamResponse(
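A note on the usage accounting above: with parallel sampling, every group of request.n choices shares one prompt, so prompt tokens are summed only for indices with index % request.n == 0, while completion tokens are summed over all choices. A toy illustration with made-up numbers:

n = 2  # request.n: two samples per prompt
prompt_tokens = {0: 5, 1: 5, 2: 7, 3: 7}        # per-choice prompt sizes (two prompts)
completion_tokens = {0: 10, 1: 12, 2: 8, 3: 9}  # per-choice generated tokens

total_prompt_tokens = sum(t for i, t in prompt_tokens.items() if i % n == 0)  # 5 + 7 = 12
total_completion_tokens = sum(completion_tokens.values())                     # 39
print(total_prompt_tokens, total_completion_tokens)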
@@ -914,16 +950,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     if adapted_request.stream:

         async def generate_stream_resp():
-            is_first = True
-
-            stream_buffer = ""
-            n_prev_token = 0
+            is_firsts = {}
+            stream_buffers = {}
+            n_prev_tokens = {}
+            prompt_tokens = {}
+            completion_tokens = {}
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
-                    prompt_tokens = content["meta_info"]["prompt_tokens"]
-                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    index = content["index"]
+
+                    is_first = is_firsts.get(index, True)
+                    stream_buffer = stream_buffers.get(index, "")
+                    n_prev_token = n_prev_tokens.get(index, 0)
+
+                    prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+                    completion_tokens[index] = content["meta_info"]["completion_tokens"]
                     if request.logprobs:
                         logprobs = to_openai_style_logprobs(
                             output_token_logprobs=content["meta_info"][
@@ -973,7 +1016,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         # First chunk with role
                         is_first = False
                         choice_data = ChatCompletionResponseStreamChoice(
-                            index=0,
+                            index=index,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=format_finish_reason(
                                 content["meta_info"]["finish_reason"]
@@ -991,7 +1034,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     delta = text[len(stream_buffer) :]
                     stream_buffer = stream_buffer + delta
                     choice_data = ChatCompletionResponseStreamChoice(
-                        index=0,
+                        index=index,
                         delta=DeltaMessage(content=delta),
                         finish_reason=format_finish_reason(
                             content["meta_info"]["finish_reason"]
@@ -1003,12 +1046,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choices=[choice_data],
                         model=request.model,
                     )
+
+                    is_firsts[index] = is_first
+                    stream_buffers[index] = stream_buffer
+                    n_prev_tokens[index] = n_prev_token
+
                     yield f"data: {chunk.model_dump_json()}\n\n"
             if request.stream_options and request.stream_options.include_usage:
+                total_prompt_tokens = sum(
+                    tokens
+                    for i, tokens in prompt_tokens.items()
+                    if i % request.n == 0
+                )
+                total_completion_tokens = sum(
+                    tokens for tokens in completion_tokens.values()
+                )
                 usage = UsageInfo(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=prompt_tokens + completion_tokens,
+                    prompt_tokens=total_prompt_tokens,
+                    completion_tokens=total_completion_tokens,
+                    total_tokens=total_prompt_tokens + total_completion_tokens,
                 )

                 final_usage_chunk = ChatCompletionStreamResponse(
@@ -85,13 +85,26 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_completion_stream(self, echo, logprobs, token_input):
+    def run_completion_stream(
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+    ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
         if token_input:
-            prompt_arg = self.tokenizer.encode(prompt)
+            prompt_input = self.tokenizer.encode(prompt)
+            num_prompt_tokens = len(prompt_input)
         else:
-            prompt_arg = prompt
+            prompt_input = prompt
+            num_prompt_tokens = len(self.tokenizer.encode(prompt))
+
+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
+        else:
+            prompt_arg = prompt_input
+            num_choices = 1
+
         generator = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
@@ -101,9 +114,10 @@ class TestOpenAIServer(unittest.TestCase):
             logprobs=logprobs,
             stream=True,
             stream_options={"include_usage": True},
+            n=parallel_sample_num,
         )

-        first = True
+        is_firsts = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -111,10 +125,14 @@ class TestOpenAIServer(unittest.TestCase):
                 assert usage.completion_tokens > 0
                 assert usage.total_tokens > 0
                 continue
+
+            index = response.choices[0].index
+            is_first = is_firsts.get(index, True)
+
             if logprobs:
                 assert response.choices[0].logprobs
                 assert isinstance(response.choices[0].logprobs.tokens[0], str)
-                if not (first and echo):
+                if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
                     )
@@ -125,15 +143,20 @@ class TestOpenAIServer(unittest.TestCase):
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                     assert ret_num_top_logprobs > 0

-            if first:
+            if is_first:
                 if echo:
                     assert response.choices[0].text.startswith(
                         prompt
-                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
-                first = False
+                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
+                is_firsts[index] = False
             assert response.id
             assert response.created

+        for index in [i for i in range(parallel_sample_num * num_choices)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
     def run_chat_completion(self, logprobs, parallel_sample_num):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
@@ -172,7 +195,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_chat_completion_stream(self, logprobs):
+    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
@@ -185,9 +208,10 @@ class TestOpenAIServer(unittest.TestCase):
             top_logprobs=logprobs,
             stream=True,
             stream_options={"include_usage": True},
+            n=parallel_sample_num,
         )

-        is_first = True
+        is_firsts = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -196,11 +220,12 @@ class TestOpenAIServer(unittest.TestCase):
                 assert usage.total_tokens > 0
                 continue

+            index = response.choices[0].index
             data = response.choices[0].delta

-            if is_first:
-                data.role == "assistant"
-                is_first = False
+            if is_firsts.get(index, True):
+                assert data.role == "assistant"
+                is_firsts[index] = False
                 continue

             if logprobs:
@@ -222,6 +247,11 @@ class TestOpenAIServer(unittest.TestCase):
             assert response.id
             assert response.created

+        for index in [i for i in range(parallel_sample_num)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
     def run_batch(self, mode):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         if mode == "completion":
@@ -320,7 +350,9 @@ class TestOpenAIServer(unittest.TestCase):
                 f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
             )
             batch_job = client.batches.retrieve(batch_job.id)
-        assert batch_job.status == "completed"
+        assert (
+            batch_job.status == "completed"
+        ), f"Batch job status is not completed: {batch_job.status}"
         assert batch_job.request_counts.completed == len(content)
         assert batch_job.request_counts.failed == 0
         assert batch_job.request_counts.total == len(content)
@@ -353,8 +385,16 @@ class TestOpenAIServer(unittest.TestCase):
         # parallel sampling adn list input are not supported in streaming mode
         for echo in [False, True]:
             for logprobs in [None, 5]:
-                for token_input in [False, True]:
-                    self.run_completion_stream(echo, logprobs, token_input)
+                for use_list_input in [True, False]:
+                    for parallel_sample_num in [1, 2]:
+                        for token_input in [False, True]:
+                            self.run_completion_stream(
+                                echo,
+                                logprobs,
+                                use_list_input,
+                                parallel_sample_num,
+                                token_input,
+                            )

     def test_chat_completion(self):
         for logprobs in [None, 5]:
@@ -363,7 +403,8 @@ class TestOpenAIServer(unittest.TestCase):

     def test_chat_completion_stream(self):
         for logprobs in [None, 5]:
-            self.run_chat_completion_stream(logprobs)
+            for parallel_sample_num in [1, 2]:
+                self.run_chat_completion_stream(logprobs, parallel_sample_num)

     def test_batch(self):
         for mode in ["completion", "chat"]:
@@ -23,7 +23,12 @@ class TestSRTEndpoint(unittest.TestCase):
         kill_child_process(cls.process.pid)

     def run_decode(
-        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
+        self,
+        return_logprob=False,
+        top_logprobs_num=0,
+        return_text=False,
+        n=1,
+        stream=False,
     ):
         response = requests.post(
             self.base_url + "/generate",
@@ -34,14 +39,21 @@ class TestSRTEndpoint(unittest.TestCase):
                     "max_new_tokens": 32,
                     "n": n,
                 },
-                "stream": False,
+                "stream": stream,
                 "return_logprob": return_logprob,
                 "top_logprobs_num": top_logprobs_num,
                 "return_text_in_logprobs": return_text,
                 "logprob_start_len": 0,
             },
         )
-        print(json.dumps(response.json()))
+        if not stream:
+            response_json = response.json()
+        else:
+            response_json = []
+            for line in response.iter_lines():
+                if line.startswith(b"data: ") and line[6:] != b"[DONE]":
+                    response_json.append(json.loads(line[6:]))
+        print(json.dumps(response_json))
         print("=" * 100)

     def test_simple_decode(self):
@@ -50,6 +62,9 @@ class TestSRTEndpoint(unittest.TestCase):
     def test_parallel_sample(self):
         self.run_decode(n=3)

+    def test_parallel_sample_stream(self):
+        self.run_decode(n=3, stream=True)
+
     def test_logprob(self):
         for top_logprobs_num in [0, 3]:
             for return_text in [True, False]:
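For completeness, the native /generate endpoint can be exercised the same way the new test does; a hedged sketch (the server address is a placeholder, and the chunk schema beyond text, index, and meta_info is not shown in this diff). Since the streamed text field is cumulative, overwriting per index keeps the latest text for each sample:

import json
import requests

# Placeholder server address; request shape mirrors run_decode in the test above.
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 32, "n": 3},
        "stream": True,
    },
    stream=True,
)

outputs = {}
for line in resp.iter_lines():
    if line.startswith(b"data: ") and line[6:] != b"[DONE]":
        chunk = json.loads(line[6:])
        # chunk["text"] is the text generated so far for this sample.
        outputs[chunk["index"]] = chunk["text"]

for index, text in sorted(outputs.items()):
    print(index, text)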