diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index b51c12816..affa720f5 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -500,7 +500,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
             responses.append(response)
         return responses
     else:
-        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
@@ -707,8 +709,6 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
 
 def v1_chat_generate_response(request, ret, to_file=False):
     choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
         logprobs = False
@@ -747,8 +747,6 @@
             choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
             choice_logprobs = None
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
             # to make the choice data json serializable
@@ -767,8 +765,7 @@
             )
 
         choices.append(choice_data)
-        total_prompt_tokens += prompt_tokens
-        total_completion_tokens += completion_tokens
+
 
     if to_file:
         responses = []
@@ -795,14 +792,18 @@
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index fef8da9ed..c98728ca8 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -45,11 +45,6 @@ class TestOpenAIServer(unittest.TestCase):
             prompt_arg = prompt_input
             num_choices = 1
 
-        if parallel_sample_num:
-            # FIXME: This is wrong. We should not count the prompt tokens multiple times for
-            # parallel sampling.
-            num_prompt_tokens *= parallel_sample_num
-
         response = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
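
Note on the usage accounting (illustration, not part of the patch): with n parallel samples per prompt, ret holds n entries per prompt, and each entry repeats its prompt's token count in meta_info. Summing prompt_tokens over every item therefore multi-counts the prompt, while completion tokens are genuinely distinct per sample. The stride-of-n sum counts each prompt once, assuming samples for the same prompt are laid out contiguously in ret. A minimal runnable sketch with fabricated meta_info payloads:

# Sketch: why the stride-based sum avoids multi-counting prompt tokens.
# The `ret` payload below is fabricated; real entries carry more fields.
request_n = 2  # stands in for request.n: two parallel samples per prompt

# Two prompts, n=2 -> four generation results; each sample repeats its
# prompt's token count in meta_info.
ret = [
    {"meta_info": {"prompt_tokens": 5, "completion_tokens": 7}},
    {"meta_info": {"prompt_tokens": 5, "completion_tokens": 9}},
    {"meta_info": {"prompt_tokens": 8, "completion_tokens": 4}},
    {"meta_info": {"prompt_tokens": 8, "completion_tokens": 6}},
]

# Old (buggy): counts each prompt once per sample -> 5 + 5 + 8 + 8 = 26.
buggy_prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)

# New: step through ret in strides of n, so each prompt is counted once
# -> ret[0] + ret[2] = 5 + 8 = 13.
prompt_tokens = sum(
    ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request_n)
)

# Completion tokens are distinct per sample, so sum them all.
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)

assert buggy_prompt_tokens == 26
assert prompt_tokens == 13
assert completion_tokens == 26

This is also why the test-side workaround (multiplying num_prompt_tokens by parallel_sample_num) can be deleted: the server now reports the prompt tokens once per prompt regardless of n.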