[OAI Server Refactor] [ChatCompletions & Completions] Implement UsageInfo Processor (#7360)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -192,6 +192,17 @@ async def v1_score_request(raw_request: Request):
|
||||
pass
|
||||
|
||||
|
||||
@app.api_route("/v1/models/{model_id}", methods=["GET"])
async def show_model_detail(model_id: str):
    """Return the ModelCard describing the served model.

    NOTE(review): ``model_id`` is accepted for OpenAI API compatibility but is
    not used for lookup here -- the single served model is always returned.
    Confirm this matches the intended multi-model behavior.
    """
    manager = app.state.tokenizer_manager
    model_name = manager.served_model_name
    return ModelCard(
        id=model_name,
        root=model_name,
        max_model_len=manager.model_config.context_len,
    )
|
||||
|
||||
|
||||
# Additional API endpoints will be implemented in separate serving_*.py modules
|
||||
# and mounted as APIRouters in future PRs
|
||||
|
||||
|
||||
@@ -114,33 +114,6 @@ class OpenAIServingBase(ABC):
|
||||
"""Validate request"""
|
||||
pass
|
||||
|
||||
def _calculate_streaming_usage_base(
    self,
    prompt_tokens: Dict[int, int],
    completion_tokens: Dict[int, int],
    cached_tokens: Dict[int, int],
    n_choices: int,
) -> UsageInfo:
    """Build a UsageInfo for a streaming response from per-index token counts.

    Prompt tokens are counted once per prompt, not once per choice: indices
    with ``i % n_choices == 0`` mark the first choice of each prompt.
    Cached-token details are attached only when the server has cache
    reporting enabled and the cached total is positive.
    """
    total_prompt = 0
    for idx, count in prompt_tokens.items():
        if idx % n_choices == 0:
            total_prompt += count

    total_completion = sum(completion_tokens.values())

    details = None
    if self.tokenizer_manager.server_args.enable_cache_report:
        cached_sum = sum(cached_tokens.values())
        if cached_sum > 0:
            details = {"cached_tokens": cached_sum}

    return UsageInfo(
        prompt_tokens=total_prompt,
        completion_tokens=total_completion,
        total_tokens=total_prompt + total_completion,
        prompt_tokens_details=details,
    )
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str,
|
||||
|
||||
@@ -26,8 +26,8 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
TopLogprob,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
|
||||
from sglang.srt.entrypoints.openai.utils import (
|
||||
aggregate_token_usage,
|
||||
detect_template_content_format,
|
||||
process_content_for_template_format,
|
||||
to_openai_style_logprobs,
|
||||
@@ -546,11 +546,12 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
|
||||
# Additional usage chunk
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = self._calculate_streaming_usage_base(
|
||||
usage = UsageProcessor.calculate_streaming_usage(
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cached_tokens,
|
||||
request.n,
|
||||
n_choices=request.n,
|
||||
enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
|
||||
)
|
||||
usage_chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
@@ -658,7 +659,9 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
|
||||
# Calculate usage
|
||||
cache_report = self.tokenizer_manager.server_args.enable_cache_report
|
||||
usage = aggregate_token_usage(ret, request.n, cache_report)
|
||||
usage = UsageProcessor.calculate_response_usage(
|
||||
ret, n_choices=request.n, enable_cache_report=cache_report
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=ret[0]["meta_info"]["id"],
|
||||
|
||||
@@ -18,10 +18,8 @@ from sglang.srt.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||
from sglang.srt.entrypoints.openai.utils import (
|
||||
aggregate_token_usage,
|
||||
to_openai_style_logprobs,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
|
||||
from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -214,11 +212,12 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
|
||||
# Handle final usage chunk
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
usage = self._calculate_streaming_usage_base(
|
||||
usage = UsageProcessor.calculate_streaming_usage(
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cached_tokens,
|
||||
request.n,
|
||||
n_choices=request.n,
|
||||
enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
|
||||
)
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
@@ -322,7 +321,9 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
|
||||
# Calculate usage
|
||||
cache_report = self.tokenizer_manager.server_args.enable_cache_report
|
||||
usage = aggregate_token_usage(ret, request.n, cache_report)
|
||||
usage = UsageProcessor.calculate_response_usage(
|
||||
ret, n_choices=request.n, enable_cache_report=cache_report
|
||||
)
|
||||
|
||||
return CompletionResponse(
|
||||
id=ret[0]["meta_info"]["id"],
|
||||
|
||||
81
python/sglang/srt/entrypoints/openai/usage_processor.py
Normal file
81
python/sglang/srt/entrypoints/openai/usage_processor.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Mapping, Optional, final
|
||||
|
||||
from python.sglang.srt.entrypoints.openai.protocol import UsageInfo
|
||||
|
||||
|
||||
@final
|
||||
class UsageProcessor:
|
||||
"""Stateless helpers that turn raw token counts into a UsageInfo."""
|
||||
|
||||
@staticmethod
|
||||
def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
|
||||
"""Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
|
||||
return {"cached_tokens": count} if count > 0 else None
|
||||
|
||||
@staticmethod
|
||||
def calculate_response_usage(
|
||||
responses: List[Dict[str, Any]],
|
||||
n_choices: int = 1,
|
||||
enable_cache_report: bool = False,
|
||||
) -> UsageInfo:
|
||||
completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)
|
||||
|
||||
prompt_tokens = sum(
|
||||
responses[i]["meta_info"]["prompt_tokens"]
|
||||
for i in range(0, len(responses), n_choices)
|
||||
)
|
||||
|
||||
cached_details = None
|
||||
if enable_cache_report:
|
||||
cached_total = sum(
|
||||
r["meta_info"].get("cached_tokens", 0) for r in responses
|
||||
)
|
||||
cached_details = UsageProcessor._details_if_cached(cached_total)
|
||||
|
||||
return UsageProcessor.calculate_token_usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cached_tokens=cached_details,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calculate_streaming_usage(
|
||||
prompt_tokens: Mapping[int, int],
|
||||
completion_tokens: Mapping[int, int],
|
||||
cached_tokens: Mapping[int, int],
|
||||
n_choices: int,
|
||||
enable_cache_report: bool = False,
|
||||
) -> UsageInfo:
|
||||
# index % n_choices == 0 marks the first choice of a prompt
|
||||
total_prompt_tokens = sum(
|
||||
tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
|
||||
)
|
||||
total_completion_tokens = sum(completion_tokens.values())
|
||||
|
||||
cached_details = (
|
||||
UsageProcessor._details_if_cached(sum(cached_tokens.values()))
|
||||
if enable_cache_report
|
||||
else None
|
||||
)
|
||||
|
||||
return UsageProcessor.calculate_token_usage(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
cached_tokens=cached_details,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calculate_token_usage(
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cached_tokens: Optional[Dict[str, int]] = None,
|
||||
) -> UsageInfo:
|
||||
"""Calculate token usage information"""
|
||||
return UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
prompt_tokens_details=cached_tokens,
|
||||
)
|
||||
@@ -1,10 +1,9 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import jinja2.nodes
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo
|
||||
from sglang.srt.entrypoints.openai.protocol import LogProbs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -171,62 +170,6 @@ def process_content_for_template_format(
|
||||
return new_msg
|
||||
|
||||
|
||||
def calculate_token_usage(
    prompt_tokens: int,
    completion_tokens: int,
    cached_tokens: Optional[Dict[str, int]] = None,
) -> UsageInfo:
    """Assemble a UsageInfo from prompt/completion token counts.

    ``total_tokens`` is always the sum of the two counts. ``cached_tokens``,
    when given, is expected to already be shaped as {"cached_tokens": N} and
    is forwarded unchanged as ``prompt_tokens_details``.
    """
    total = prompt_tokens + completion_tokens
    return UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total,
        prompt_tokens_details=cached_tokens,
    )
|
||||
|
||||
|
||||
def aggregate_token_usage(
    responses: List[Dict[str, Any]],
    n_choices: int = 1,
    enable_cache_report: bool = False,
) -> UsageInfo:
    """Aggregate token usage from multiple responses.

    Args:
        responses: Response dictionaries, each carrying a ``meta_info`` dict.
        n_choices: Number of choices generated per prompt; prompt tokens are
            counted only once per prompt (every n_choices-th entry).
        enable_cache_report: Whether to include cached token details.

    Returns:
        Aggregated UsageInfo.
    """
    completion_total = 0
    for resp in responses:
        completion_total += resp["meta_info"]["completion_tokens"]

    # Count prompt tokens once per prompt, not once per choice.
    prompt_total = sum(
        resp["meta_info"]["prompt_tokens"] for resp in responses[::n_choices]
    )

    details = None
    if enable_cache_report:
        cached_total = sum(
            resp["meta_info"].get("cached_tokens", 0) for resp in responses
        )
        if cached_total > 0:
            details = {"cached_tokens": cached_total}

    return calculate_token_usage(
        prompt_tokens=prompt_total,
        completion_tokens=completion_total,
        cached_tokens=details,
    )
|
||||
|
||||
|
||||
def to_openai_style_logprobs(
|
||||
input_token_logprobs=None,
|
||||
output_token_logprobs=None,
|
||||
|
||||
Reference in New Issue
Block a user