Multiple minor fixes (#1530)
This commit is contained in:
@@ -235,6 +235,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||
obj = self._generate_http_request(s, data)
|
||||
prompt_len = obj["meta_info"]["prompt_tokens"]
|
||||
logprob_start_len = max(prompt_len - 2, 0) # For token healing
|
||||
|
||||
# Compute logprob
|
||||
data = {
|
||||
@@ -245,7 +246,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
},
|
||||
"return_logprob": True,
|
||||
"return_text_in_logprobs": True,
|
||||
"logprob_start_len": prompt_len - 2, # For token healing
|
||||
"logprob_start_len": logprob_start_len,
|
||||
}
|
||||
obj = self._generate_http_request(s, data)
|
||||
|
||||
@@ -258,8 +259,8 @@ class RuntimeEndpoint(BaseBackend):
|
||||
# Remove extra token if no token healing occurred
|
||||
for i in range(len(input_token_logprobs)):
|
||||
healed_token_str = input_token_logprobs[i][0][-1]
|
||||
healed_token_logprob = input_token_logprobs[i][0][0]
|
||||
if s.text_.endswith(healed_token_str):
|
||||
healed_token_logprob = input_token_logprobs[i][0][0]
|
||||
normalized_prompt_logprobs[i] = (
|
||||
normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
|
||||
- healed_token_logprob
|
||||
|
||||
Reference in New Issue
Block a user