Fix logit processor bugs (#427)
This commit is contained in:
@@ -185,7 +185,10 @@ class TokenizerManager:
|
||||
|
||||
while True:
|
||||
await event.wait()
|
||||
yield state.out_list[-1]
|
||||
yield self.convert_logprob_style(state.out_list[-1],
|
||||
obj.return_logprob,
|
||||
obj.top_logprobs_num,
|
||||
obj.return_text_in_logprobs)
|
||||
state.out_list = []
|
||||
if state.finished:
|
||||
del self.rid_to_state[rid]
|
||||
@@ -231,16 +234,16 @@ class TokenizerManager:
|
||||
rid = obj.rid[i]
|
||||
state = self.rid_to_state[rid]
|
||||
await state.event.wait()
|
||||
output_list.append(state.out_list[-1])
|
||||
output_list.append(
|
||||
self.convert_logprob_style(state.out_list[-1],
|
||||
obj.return_logprob[i],
|
||||
obj.top_logprobs_num[i],
|
||||
obj.return_text_in_logprobs))
|
||||
assert state.finished
|
||||
del self.rid_to_state[rid]
|
||||
|
||||
yield output_list
|
||||
|
||||
async def detokenize(self, obj: DetokenizeReqInput):
|
||||
token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
|
||||
return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
|
||||
|
||||
async def flush_cache(self):
|
||||
flush_cache_req = FlushCacheReq()
|
||||
self.send_to_router.send_pyobj(flush_cache_req)
|
||||
@@ -267,3 +270,37 @@ class TokenizerManager:
|
||||
state.event.set()
|
||||
else:
|
||||
raise ValueError(f"Invalid object: {recv_obj}")
|
||||
|
||||
def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
|
||||
if return_logprob:
|
||||
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
if top_logprobs_num > 0:
|
||||
ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
return ret
|
||||
|
||||
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
|
||||
if not decode_to_text:
|
||||
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
||||
|
||||
token_ids = [tid for _, tid in token_logprobs]
|
||||
token_texts = self.tokenizer.batch_decode(token_ids)
|
||||
return [
|
||||
(logprob, token_id, token_text)
|
||||
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
|
||||
]
|
||||
|
||||
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
|
||||
for i, t in enumerate(top_logprobs):
|
||||
if t:
|
||||
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
|
||||
return top_logprobs
|
||||
|
||||
Reference in New Issue
Block a user