Fix the double BOS problem in the HF chat template (#888)
@@ -594,7 +594,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
 
 def v1_chat_generate_request(all_requests, tokenizer_manager):
-    texts = []
+    input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
@@ -608,8 +608,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             if chat_template_name is None:
-                prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                    request.messages, tokenize=False, add_generation_prompt=True
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=True, add_generation_prompt=True
                 )
                 stop = request.stop
                 image_data = None
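
This is the core of the fix: rendering the chat template to a string and tokenizing that string later adds BOS twice, because the template already emits BOS and encode() prepends another by default. A minimal sketch of the problem and the fix, using the Hugging Face tokenizer API (the model name is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
messages = [{"role": "user", "content": "Hello!"}]

# Old path: render to text first, encode later. The templated string already
# contains BOS, and encode() (add_special_tokens=True by default) puts a
# second one in front of it.
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
old_ids = tokenizer.encode(text)

# New path: let apply_chat_template tokenize directly; BOS appears once.
new_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)

print(old_ids[:2])  # [bos_id, bos_id] -- duplicated BOS
print(new_ids[:1])  # [bos_id]         -- single BOS
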
@@ -623,12 +623,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
                         stop.append(request.stop)
                     else:
                         stop.extend(request.stop)
+                prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
             prompt = request.messages
             stop = request.stop
             image_data = None
-        texts.append(prompt)
+        input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
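
The custom conversation-template branch still builds a prompt string, so the added encode() call is where its token ids come from; the rendered string is assumed to carry no special tokens itself, so a single encode() adds BOS exactly once. A hedged sketch of that invariant, reusing the tokenizer from the sketch above (the prompt string is illustrative):

# A conversation-template prompt carries no BOS of its own, so one
# encode() call prepends exactly one.
prompt = "[INST] Hello! [/INST]"  # illustrative rendered prompt
prompt_ids = tokenizer.encode(prompt)
assert prompt_ids[0] == tokenizer.bos_token_id  # BOS present once
assert prompt_ids[1] != tokenizer.bos_token_id  # and not duplicated
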
@@ -645,13 +646,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         )
         image_data_list.append(image_data)
     if len(all_requests) == 1:
-        texts = texts[0]
+        input_ids = input_ids[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
-        text=texts,
+        input_ids=input_ids,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
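
Downstream, the adapted request now carries pre-computed token ids instead of raw text, so the tokenizer manager never re-tokenizes the prompt and cannot re-add BOS. A minimal sketch of the single-request case, with field names taken from the diff above and illustrative sampling parameters:

prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
    request.messages, tokenize=True, add_generation_prompt=True
)
adapted_request = GenerateReqInput(
    input_ids=prompt_ids,  # already tokenized, single BOS
    image_data=None,
    sampling_params={"temperature": 0.7, "max_new_tokens": 128},
    return_logprob=False,
)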