[Fix] Fix logprob and normalized_logprob (#1428)
This commit is contained in:
@@ -22,7 +22,7 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import HTTPException, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
@@ -472,7 +472,7 @@ def v1_generate_request(
|
||||
first_prompt_type = type(all_requests[0].prompt)
|
||||
for request in all_requests:
|
||||
assert (
|
||||
type(request.prompt) == first_prompt_type
|
||||
type(request.prompt) is first_prompt_type
|
||||
), "All prompts must be of the same type in file input settings"
|
||||
if len(all_requests) > 1 and request.n > 1:
|
||||
raise ValueError(
|
||||
@@ -887,7 +887,7 @@ def v1_chat_generate_request(
|
||||
input_ids.append(prompt_ids)
|
||||
return_logprobs.append(request.logprobs)
|
||||
logprob_start_lens.append(-1)
|
||||
top_logprobs_nums.append(request.top_logprobs)
|
||||
top_logprobs_nums.append(request.top_logprobs or 0)
|
||||
|
||||
sampling_params = {
|
||||
"temperature": request.temperature,
|
||||
|
||||
Reference in New Issue
Block a user