From 065bb9475344c66d468c9a7ba71fb1ea465292a0 Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Sun, 29 Sep 2024 05:04:06 +0800 Subject: [PATCH] Fix RuntimeEndpoint.select method (#1495) --- python/sglang/lang/backend/runtime_endpoint.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index e1194b6cf..dc202ff1e 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -244,7 +244,8 @@ class RuntimeEndpoint(BaseBackend): "temperature": 0, }, "return_logprob": True, - "logprob_start_len": max(prompt_len - 2, 0), # for token healing + "return_text_in_logprobs": True, + "logprob_start_len": prompt_len - 2, # For token healing } obj = self._generate_http_request(s, data) @@ -254,6 +255,17 @@ class RuntimeEndpoint(BaseBackend): input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj] output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj] + # Remove extra token if no token healing occurred + for i in range(len(input_token_logprobs)): + healed_token_str = input_token_logprobs[i][0][-1] + healed_token_logprob = input_token_logprobs[i][0][0] + if s.text_.endswith(healed_token_str): + normalized_prompt_logprobs[i] = ( + normalized_prompt_logprobs[i] * len(input_token_logprobs[i]) + - healed_token_logprob + ) / (len(input_token_logprobs[i]) - 1) + input_token_logprobs[i] = input_token_logprobs[i][1:] + # Compute unconditional logprobs if required if choices_method.requires_unconditional_logprobs: input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]