support more option about usage in stream mode (#985)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
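The change lets streaming requests opt into token-usage reporting through an OpenAI-style stream_options parameter. A minimal client-side sketch (not part of the diff; the model name and server address are placeholders), mirroring the updated tests:

import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
stream = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},  # option added by this PR
)
for chunk in stream:
    if chunk.usage is not None:
        # Final chunk carries aggregate token counts and no choices.
        print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens, chunk.usage.total_tokens)
    elif chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")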
@@ -217,7 +217,9 @@ class Req:
             return
 
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
-            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
+            self.finished_reason = FINISH_LENGTH(
+                length=self.sampling_params.max_new_tokens
+            )
             return
 
         if (
@@ -84,6 +84,19 @@ file_id_storage: Dict[str, str] = {}
 storage_dir = None
 
 
+def format_finish_reason(finish_reason) -> Optional[str]:
+    if finish_reason.startswith("None"):
+        return None
+    elif finish_reason.startswith("FINISH_MATCHED"):
+        return "stop"
+    elif finish_reason.startswith("FINISH_LENGTH"):
+        return "length"
+    elif finish_reason.startswith("FINISH_ABORT"):
+        return "abort"
+    else:
+        return "unknown"
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
@@ -486,14 +499,18 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": ret_item["meta_info"]["finish_reason"],
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = CompletionResponseChoice(
                 index=idx,
                 text=text,
                 logprobs=logprobs,
-                finish_reason=ret_item["meta_info"]["finish_reason"],
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )
 
         choices.append(choice_data)
@@ -608,20 +625,34 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         index=0,
                         text=delta,
                         logprobs=logprobs,
-                        finish_reason=content["meta_info"]["finish_reason"],
+                        finish_reason=format_finish_reason(
+                            content["meta_info"]["finish_reason"]
+                        ),
                     )
                     chunk = CompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         object="text_completion",
                         choices=[choice_data],
                         model=request.model,
-                        usage=UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
+                if request.stream_options and request.stream_options.include_usage:
+                    usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
+
+                    final_usage_chunk = CompletionStreamResponse(
+                        id=str(uuid.uuid4().hex),
+                        choices=[],
+                        model=request.model,
+                        usage=usage,
+                    )
+                    final_usage_data = final_usage_chunk.model_dump_json(
+                        exclude_unset=True, exclude_none=True
+                    )
+                    yield f"data: {final_usage_data}\n\n"
             except ValueError as e:
                 error = create_streaming_error_response(str(e))
                 yield f"data: {error}\n\n"
@@ -776,14 +807,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
                 "logprobs": choice_logprobs,
-                "finish_reason": ret_item["meta_info"]["finish_reason"],
+                "finish_reason": format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
                 logprobs=choice_logprobs,
-                finish_reason=ret_item["meta_info"]["finish_reason"],
+                finish_reason=format_finish_reason(
+                    ret_item["meta_info"]["finish_reason"]
+                ),
             )
 
         choices.append(choice_data)
@@ -900,18 +935,15 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=0,
                             delta=DeltaMessage(role="assistant"),
-                            finish_reason=content["meta_info"]["finish_reason"],
+                            finish_reason=format_finish_reason(
+                                content["meta_info"]["finish_reason"]
+                            ),
                             logprobs=choice_logprobs,
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
                             choices=[choice_data],
                             model=request.model,
-                            usage=UsageInfo(
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=prompt_tokens + completion_tokens,
-                            ),
                         )
                         yield f"data: {chunk.model_dump_json()}\n\n"
 
@@ -921,20 +953,34 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=0,
                         delta=DeltaMessage(content=delta),
-                        finish_reason=content["meta_info"]["finish_reason"],
+                        finish_reason=format_finish_reason(
+                            content["meta_info"]["finish_reason"]
+                        ),
                         logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
-                        usage=UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
+                if request.stream_options and request.stream_options.include_usage:
+                    usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
+
+                    final_usage_chunk = ChatCompletionStreamResponse(
+                        id=str(uuid.uuid4().hex),
+                        choices=[],
+                        model=request.model,
+                        usage=usage,
+                    )
+                    final_usage_data = final_usage_chunk.model_dump_json(
+                        exclude_unset=True, exclude_none=True
+                    )
+                    yield f"data: {final_usage_data}\n\n"
             except ValueError as e:
                 error = create_streaming_error_response(str(e))
                 yield f"data: {error}\n\n"
@@ -78,6 +78,10 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0
 
 
+class StreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class FileRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/files/create
     file: bytes  # The File object (not file name) to be uploaded
@@ -149,6 +153,7 @@ class CompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
@@ -188,7 +193,7 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
-    usage: UsageInfo
+    usage: Optional[UsageInfo] = None
 
 
 class ChatCompletionMessageGenericParam(BaseModel):
@@ -247,6 +252,7 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     user: Optional[str] = None
@@ -294,6 +300,7 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = None
 
 
 class EmbeddingRequest(BaseModel):
@@ -310,3 +317,4 @@ class EmbeddingResponse(BaseModel):
     index: str
     embedding: List[float] = None
     object: str = "embedding"
+    usage: Optional[UsageInfo] = None
@@ -98,10 +98,17 @@ class TestOpenAIServer(unittest.TestCase):
             echo=echo,
             logprobs=logprobs,
             stream=True,
+            stream_options={"include_usage": True},
         )
 
         first = True
         for response in generator:
+            usage = response.usage
+            if usage is not None:
+                assert usage.prompt_tokens > 0
+                assert usage.completion_tokens > 0
+                assert usage.total_tokens > 0
+                continue
             if logprobs:
                 assert response.choices[0].logprobs
                 assert isinstance(response.choices[0].logprobs.tokens[0], str)
@@ -122,12 +129,8 @@ class TestOpenAIServer(unittest.TestCase):
                     prompt
                 ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
                 first = False
-
             assert response.id
             assert response.created
-            assert response.usage.prompt_tokens > 0
-            assert response.usage.completion_tokens > 0
-            assert response.usage.total_tokens > 0
 
     def run_chat_completion(self, logprobs, parallel_sample_num):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -179,11 +182,20 @@ class TestOpenAIServer(unittest.TestCase):
             logprobs=logprobs is not None and logprobs > 0,
             top_logprobs=logprobs,
             stream=True,
+            stream_options={"include_usage": True},
         )
 
         is_first = True
         for response in generator:
+            usage = response.usage
+            if usage is not None:
+                assert usage.prompt_tokens > 0
+                assert usage.completion_tokens > 0
+                assert usage.total_tokens > 0
+                continue
+
             data = response.choices[0].delta
+
             if is_first:
                 data.role == "assistant"
                 is_first = False
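Note on the wire format: because the final usage chunk is serialized with exclude_unset=True and exclude_none=True, only the fields the handler sets explicitly appear in that last SSE event. A rough sketch of the shape a client receives (values are illustrative, not taken from a real run):

final_usage_event = {
    "id": "d41d8cd98f00b204e9800998ecf8427e",  # fresh uuid4 hex, not the request id
    "choices": [],  # empty: this chunk only reports usage
    "model": "default",  # placeholder model name
    "usage": {
        "prompt_tokens": 5,
        "completion_tokens": 16,
        "total_tokens": 21,
    },
}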