[Fix] Fix logprob and normalized_logprob (#1428)

This commit is contained in:
Lianmin Zheng
2024-09-15 06:36:06 -07:00
committed by GitHub
parent 282681b8a1
commit 9ba1f09760
22 changed files with 314 additions and 215 deletions

View File

@@ -22,7 +22,7 @@ import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List, Optional
from typing import Dict, List
from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
@@ -472,7 +472,7 @@ def v1_generate_request(
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) == first_prompt_type
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if len(all_requests) > 1 and request.n > 1:
raise ValueError(
@@ -887,7 +887,7 @@ def v1_chat_generate_request(
input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs)
top_logprobs_nums.append(request.top_logprobs or 0)
sampling_params = {
"temperature": request.temperature,