From fbd6b94d6982298c1b488779e581029a1792df9b Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Fri, 2 Aug 2024 00:30:50 -0700
Subject: [PATCH] Fix the double BOS problem in the HF chat template (#888)

---
 python/sglang/srt/openai_api/adapter.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index c52d298d8..8ad028b1d 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -594,7 +594,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
 
 def v1_chat_generate_request(all_requests, tokenizer_manager):
-    texts = []
+    input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
@@ -608,8 +608,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             if chat_template_name is None:
-                prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                    request.messages, tokenize=False, add_generation_prompt=True
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=True, add_generation_prompt=True
                 )
                 stop = request.stop
                 image_data = None
@@ -623,12 +623,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
                     stop.append(request.stop)
                 else:
                     stop.extend(request.stop)
+            prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
             prompt = request.messages
             stop = request.stop
             image_data = None
-        texts.append(prompt)
+        input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
@@ -645,13 +646,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             )
         )
         image_data_list.append(image_data)
     if len(all_requests) == 1:
-        texts = texts[0]
+        input_ids = input_ids[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
-        text=texts,
+        input_ids=input_ids,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,