[OAI Server Refactor] [ChatCompletions & Completions] Implement UsageInfo Processor (#7360)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
@@ -192,6 +192,17 @@ async def v1_score_request(raw_request: Request):
     pass
 
 
+@app.api_route("/v1/models/{model_id}", methods=["GET"])
+async def show_model_detail(model_id: str):
+    served_model_name = app.state.tokenizer_manager.served_model_name
+
+    return ModelCard(
+        id=served_model_name,
+        root=served_model_name,
+        max_model_len=app.state.tokenizer_manager.model_config.context_len,
+    )
+
+
 # Additional API endpoints will be implemented in separate serving_*.py modules
 # and mounted as APIRouters in future PRs
 
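A quick way to exercise the new route. This is a client-side sketch only; the base URL, port, and served model name are assumptions, not anything fixed by this diff:

```python
import requests

# Assumed local server address; adjust to your deployment.
BASE_URL = "http://localhost:30000"

# Query the per-model detail route added in this hunk.
resp = requests.get(f"{BASE_URL}/v1/models/my-served-model")
resp.raise_for_status()

card = resp.json()
# The handler fills id/root from the served model name and
# max_model_len from the model config's context length.
print(card["id"], card.get("max_model_len"))
```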
@@ -114,33 +114,6 @@ class OpenAIServingBase(ABC):
         """Validate request"""
         pass
 
-    def _calculate_streaming_usage_base(
-        self,
-        prompt_tokens: Dict[int, int],
-        completion_tokens: Dict[int, int],
-        cached_tokens: Dict[int, int],
-        n_choices: int,
-    ) -> UsageInfo:
-        """Calculate usage information for streaming responses (common logic)"""
-        total_prompt_tokens = sum(
-            tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0
-        )
-        total_completion_tokens = sum(tokens for tokens in completion_tokens.values())
-
-        cache_report = self.tokenizer_manager.server_args.enable_cache_report
-        prompt_tokens_details = None
-        if cache_report:
-            cached_tokens_sum = sum(tokens for tokens in cached_tokens.values())
-            if cached_tokens_sum > 0:
-                prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
-
-        return UsageInfo(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=total_completion_tokens,
-            total_tokens=total_prompt_tokens + total_completion_tokens,
-            prompt_tokens_details=prompt_tokens_details,
-        )
-
     def create_error_response(
         self,
         message: str,
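The removed helper (and the UsageProcessor.calculate_streaming_usage that replaces it later in this diff) keys per-choice token counts by choice index and counts each prompt once by summing only indices where idx % n_choices == 0. A tiny illustration with made-up counts:

```python
# Hypothetical per-choice counters for a request with n_choices=2:
# indices 0 and 1 are the two choices generated from the same prompt.
prompt_tokens = {0: 12, 1: 12}       # both choices share one 12-token prompt
completion_tokens = {0: 7, 1: 9}

n_choices = 2
total_prompt = sum(t for i, t in prompt_tokens.items() if i % n_choices == 0)
total_completion = sum(completion_tokens.values())

assert total_prompt == 12            # prompt counted once, not twice
assert total_completion == 16
```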
@@ -26,8 +26,8 @@ from sglang.srt.entrypoints.openai.protocol import (
     TopLogprob,
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
 from sglang.srt.entrypoints.openai.utils import (
-    aggregate_token_usage,
     detect_template_content_format,
     process_content_for_template_format,
     to_openai_style_logprobs,
@@ -546,11 +546,12 @@ class OpenAIServingChat(OpenAIServingBase):
 
         # Additional usage chunk
         if request.stream_options and request.stream_options.include_usage:
-            usage = self._calculate_streaming_usage_base(
+            usage = UsageProcessor.calculate_streaming_usage(
                 prompt_tokens,
                 completion_tokens,
                 cached_tokens,
-                request.n,
+                n_choices=request.n,
+                enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
             )
             usage_chunk = ChatCompletionStreamResponse(
                 id=content["meta_info"]["id"],
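As before, the trailing usage chunk is only emitted when the client sets stream_options.include_usage. A minimal client-side sketch with the openai package; the endpoint URL, API key, and model name are placeholders, not part of this change:

```python
from openai import OpenAI

# Placeholder endpoint/credentials for a locally served model.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="my-served-model",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},  # request the trailing usage chunk
)

for chunk in stream:
    if chunk.usage is not None:
        # The final chunk carries prompt/completion/total token counts, plus
        # prompt_tokens_details.cached_tokens when cache reporting is enabled.
        print(chunk.usage)
```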
@@ -658,7 +659,9 @@ class OpenAIServingChat(OpenAIServingBase):
 
         # Calculate usage
         cache_report = self.tokenizer_manager.server_args.enable_cache_report
-        usage = aggregate_token_usage(ret, request.n, cache_report)
+        usage = UsageProcessor.calculate_response_usage(
+            ret, n_choices=request.n, enable_cache_report=cache_report
+        )
 
         return ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
@@ -18,10 +18,8 @@ from sglang.srt.entrypoints.openai.protocol import (
     ErrorResponse,
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
-from sglang.srt.entrypoints.openai.utils import (
-    aggregate_token_usage,
-    to_openai_style_logprobs,
-)
+from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
+from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
 from sglang.srt.managers.io_struct import GenerateReqInput
 
 logger = logging.getLogger(__name__)
@@ -214,11 +212,12 @@ class OpenAIServingCompletion(OpenAIServingBase):
 
         # Handle final usage chunk
         if request.stream_options and request.stream_options.include_usage:
-            usage = self._calculate_streaming_usage_base(
+            usage = UsageProcessor.calculate_streaming_usage(
                 prompt_tokens,
                 completion_tokens,
                 cached_tokens,
-                request.n,
+                n_choices=request.n,
+                enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
            )
             final_usage_chunk = CompletionStreamResponse(
                 id=content["meta_info"]["id"],
@@ -322,7 +321,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
 
         # Calculate usage
         cache_report = self.tokenizer_manager.server_args.enable_cache_report
-        usage = aggregate_token_usage(ret, request.n, cache_report)
+        usage = UsageProcessor.calculate_response_usage(
+            ret, n_choices=request.n, enable_cache_report=cache_report
+        )
 
         return CompletionResponse(
             id=ret[0]["meta_info"]["id"],
python/sglang/srt/entrypoints/openai/usage_processor.py (new file, +81 lines)
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional, final
+
+from sglang.srt.entrypoints.openai.protocol import UsageInfo
+
+
+@final
+class UsageProcessor:
+    """Stateless helpers that turn raw token counts into a UsageInfo."""
+
+    @staticmethod
+    def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
+        """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
+        return {"cached_tokens": count} if count > 0 else None
+
+    @staticmethod
+    def calculate_response_usage(
+        responses: List[Dict[str, Any]],
+        n_choices: int = 1,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)
+
+        prompt_tokens = sum(
+            responses[i]["meta_info"]["prompt_tokens"]
+            for i in range(0, len(responses), n_choices)
+        )
+
+        cached_details = None
+        if enable_cache_report:
+            cached_total = sum(
+                r["meta_info"].get("cached_tokens", 0) for r in responses
+            )
+            cached_details = UsageProcessor._details_if_cached(cached_total)
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_streaming_usage(
+        prompt_tokens: Mapping[int, int],
+        completion_tokens: Mapping[int, int],
+        cached_tokens: Mapping[int, int],
+        n_choices: int,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        # index % n_choices == 0 marks the first choice of a prompt
+        total_prompt_tokens = sum(
+            tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
+        )
+        total_completion_tokens = sum(completion_tokens.values())
+
+        cached_details = (
+            UsageProcessor._details_if_cached(sum(cached_tokens.values()))
+            if enable_cache_report
+            else None
+        )
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=total_prompt_tokens,
+            completion_tokens=total_completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_token_usage(
+        prompt_tokens: int,
+        completion_tokens: int,
+        cached_tokens: Optional[Dict[str, int]] = None,
+    ) -> UsageInfo:
+        """Calculate token usage information"""
+        return UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=cached_tokens,
+        )
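For reference, a minimal sketch of how the new processor is driven; the meta_info values below are fabricated, and the shapes mirror the call sites in the chat and completions handlers above:

```python
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

# Two generations for the same prompt (n_choices=2), with made-up counts.
ret = [
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 4, "cached_tokens": 8}},
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 6, "cached_tokens": 8}},
]

usage = UsageProcessor.calculate_response_usage(
    ret, n_choices=2, enable_cache_report=True
)
# prompt_tokens=10 (counted once), completion_tokens=10, total_tokens=20,
# prompt_tokens_details={"cached_tokens": 16}
print(usage)

# Streaming path: per-choice-index counters accumulated while the stream runs.
usage_stream = UsageProcessor.calculate_streaming_usage(
    prompt_tokens={0: 10, 1: 10},
    completion_tokens={0: 4, 1: 6},
    cached_tokens={0: 8, 1: 8},
    n_choices=2,
    enable_cache_report=True,
)
print(usage_stream)
```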
@@ -1,10 +1,9 @@
 import logging
-from typing import Any, Dict, List, Optional
 
 import jinja2.nodes
 import transformers.utils.chat_template_utils as hf_chat_utils
 
-from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo
+from sglang.srt.entrypoints.openai.protocol import LogProbs
 
 logger = logging.getLogger(__name__)
 
@@ -171,62 +170,6 @@ def process_content_for_template_format(
     return new_msg
 
 
-def calculate_token_usage(
-    prompt_tokens: int,
-    completion_tokens: int,
-    cached_tokens: Optional[Dict[str, int]] = None,
-) -> UsageInfo:
-    """Calculate token usage information"""
-    return UsageInfo(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-        prompt_tokens_details=cached_tokens,
-    )
-
-
-def aggregate_token_usage(
-    responses: List[Dict[str, Any]],
-    n_choices: int = 1,
-    enable_cache_report: bool = False,
-) -> UsageInfo:
-    """Aggregate token usage from multiple responses
-
-    Args:
-        responses: List of response dictionaries with meta_info
-        n_choices: Number of choices per request (for prompt token counting)
-        enable_cache_report: Whether to include cached token details
-
-    Returns:
-        Aggregated UsageInfo
-    """
-    # Sum completion tokens from all responses
-    completion_tokens = sum(
-        response["meta_info"]["completion_tokens"] for response in responses
-    )
-
-    # For prompt tokens, only count every n_choices-th response to avoid double counting
-    prompt_tokens = sum(
-        responses[i]["meta_info"]["prompt_tokens"]
-        for i in range(0, len(responses), n_choices)
-    )
-
-    # Handle cached tokens if cache reporting is enabled
-    cached_tokens_details = None
-    if enable_cache_report:
-        cached_tokens_sum = sum(
-            response["meta_info"].get("cached_tokens", 0) for response in responses
-        )
-        if cached_tokens_sum > 0:
-            cached_tokens_details = {"cached_tokens": cached_tokens_sum}
-
-    return calculate_token_usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        cached_tokens=cached_tokens_details,
-    )
-
-
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,