[Minor] Improve the code style in TokenizerManager (#767)
This commit is contained in:
@@ -376,7 +376,7 @@ class Batch:
|
|||||||
logit_bias = torch.zeros(
|
logit_bias = torch.zeros(
|
||||||
(bs, vocab_size), dtype=torch.float32, device=device
|
(bs, vocab_size), dtype=torch.float32, device=device
|
||||||
)
|
)
|
||||||
logit_bias[i] = int_token_logit_bias
|
logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
|
||||||
|
|
||||||
# Set fields
|
# Set fields
|
||||||
self.input_ids = torch.tensor(
|
self.input_ids = torch.tensor(
|
||||||
|
|||||||
@@ -133,24 +133,10 @@ class TokenizerManager:
|
|||||||
async for response in self._handle_batch_request(obj, request):
|
async for response in self._handle_batch_request(obj, request):
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
|
async def _handle_single_request(
|
||||||
if is_prefill:
|
self, obj, request, index=None, is_cache_for_prefill=False
|
||||||
if isinstance(obj.text, list):
|
):
|
||||||
input_text = obj.text[index]
|
if not is_cache_for_prefill:
|
||||||
rid = obj.rid[index]
|
|
||||||
else:
|
|
||||||
input_text = obj.text
|
|
||||||
rid = obj.rid[0]
|
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
|
||||||
sampling_params.max_new_tokens = 0
|
|
||||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
|
||||||
obj.image_data[0]
|
|
||||||
)
|
|
||||||
return_logprob = obj.return_logprob[0]
|
|
||||||
logprob_start_len = obj.logprob_start_len[0]
|
|
||||||
top_logprobs_num = obj.top_logprobs_num[0]
|
|
||||||
else:
|
|
||||||
rid = obj.rid if index is None else obj.rid[index]
|
rid = obj.rid if index is None else obj.rid[index]
|
||||||
input_text = obj.text if index is None else obj.text[index]
|
input_text = obj.text if index is None else obj.text[index]
|
||||||
input_ids = (
|
input_ids = (
|
||||||
@@ -177,6 +163,22 @@ class TokenizerManager:
|
|||||||
top_logprobs_num = (
|
top_logprobs_num = (
|
||||||
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
|
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if isinstance(obj.text, list):
|
||||||
|
input_text = obj.text[index]
|
||||||
|
rid = obj.rid[index]
|
||||||
|
else:
|
||||||
|
input_text = obj.text
|
||||||
|
rid = obj.rid[0]
|
||||||
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
|
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||||
|
sampling_params.max_new_tokens = 0
|
||||||
|
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||||
|
obj.image_data[0]
|
||||||
|
)
|
||||||
|
return_logprob = obj.return_logprob[0]
|
||||||
|
logprob_start_len = obj.logprob_start_len[0]
|
||||||
|
top_logprobs_num = obj.top_logprobs_num[0]
|
||||||
|
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
rid,
|
rid,
|
||||||
@@ -196,26 +198,26 @@ class TokenizerManager:
|
|||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
state = ReqState([], False, event)
|
state = ReqState([], False, event)
|
||||||
self.rid_to_state[rid] = state
|
self.rid_to_state[rid] = state
|
||||||
if is_prefill:
|
if not is_cache_for_prefill:
|
||||||
await self._wait_for_prefill_response(event, state, obj, request, rid)
|
|
||||||
yield input_ids
|
|
||||||
else:
|
|
||||||
async for response in self._wait_for_response(
|
async for response in self._wait_for_response(
|
||||||
event, state, obj, rid, request
|
event, state, obj, rid, request
|
||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
|
else:
|
||||||
|
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
|
||||||
|
yield input_ids
|
||||||
|
|
||||||
async def _handle_batch_request(self, obj, request):
|
async def _handle_batch_request(self, obj: GenerateReqInput, request):
|
||||||
batch_size = obj.batch_size
|
batch_size = obj.batch_size
|
||||||
parallel_sample_num = obj.sampling_params[0].get("n", 1)
|
parallel_sample_num = obj.sampling_params[0].get("n", 1)
|
||||||
|
|
||||||
if parallel_sample_num != 1:
|
if parallel_sample_num != 1:
|
||||||
## send prefill requests
|
# Send prefill requests to cache the common input
|
||||||
parallel_sample_num += 1
|
parallel_sample_num += 1
|
||||||
input_id_result = [] if obj.input_ids is None else None
|
input_id_result = [] if obj.input_ids is None else None
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
async for input_id in self._handle_single_request(
|
async for input_id in self._handle_single_request(
|
||||||
obj, request, index=i, is_prefill=True
|
obj, request, index=i, is_cache_for_prefill=True
|
||||||
):
|
):
|
||||||
if input_id_result is not None:
|
if input_id_result is not None:
|
||||||
input_id_result.append(input_id)
|
input_id_result.append(input_id)
|
||||||
@@ -224,6 +226,7 @@ class TokenizerManager:
|
|||||||
obj.input_ids = input_id_result
|
obj.input_ids = input_id_result
|
||||||
elif input_id_result is not None:
|
elif input_id_result is not None:
|
||||||
obj.input_ids = input_id_result[0]
|
obj.input_ids = input_id_result[0]
|
||||||
|
|
||||||
# First send out all requests
|
# First send out all requests
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
for j in range(parallel_sample_num):
|
for j in range(parallel_sample_num):
|
||||||
@@ -308,17 +311,15 @@ class TokenizerManager:
|
|||||||
|
|
||||||
yield output_list
|
yield output_list
|
||||||
|
|
||||||
def _validate_input_length(self, input_ids):
|
def _validate_input_length(self, input_ids: List[int]):
|
||||||
if len(input_ids) >= self.context_len:
|
if len(input_ids) >= self.context_len:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The input ({len(input_ids)} tokens) is longer than the "
|
f"The input ({len(input_ids)} tokens) is longer than the "
|
||||||
f"model's context length ({self.context_len} tokens)."
|
f"model's context length ({self.context_len} tokens)."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
|
def _get_sampling_params(self, sampling_params_data: dict):
|
||||||
sampling_params = SamplingParams(**sampling_params_data)
|
sampling_params = SamplingParams(**sampling_params_data)
|
||||||
if max_new_tokens is not None:
|
|
||||||
sampling_params.max_new_tokens = max_new_tokens
|
|
||||||
if sampling_params.max_new_tokens != 0:
|
if sampling_params.max_new_tokens != 0:
|
||||||
sampling_params.normalize(self.tokenizer)
|
sampling_params.normalize(self.tokenizer)
|
||||||
sampling_params.verify()
|
sampling_params.verify()
|
||||||
@@ -332,7 +333,14 @@ class TokenizerManager:
|
|||||||
else:
|
else:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
async def _wait_for_response(self, event, state, obj, rid, request):
|
async def _wait_for_response(
|
||||||
|
self,
|
||||||
|
event: asyncio.Event,
|
||||||
|
state: ReqState,
|
||||||
|
obj: GenerateReqInput,
|
||||||
|
rid: str,
|
||||||
|
request,
|
||||||
|
):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(event.wait(), timeout=4)
|
await asyncio.wait_for(event.wait(), timeout=4)
|
||||||
@@ -361,7 +369,14 @@ class TokenizerManager:
|
|||||||
event.clear()
|
event.clear()
|
||||||
yield out
|
yield out
|
||||||
|
|
||||||
async def _wait_for_prefill_response(self, event, state, obj, request, rid):
|
async def _wait_for_cache_prefill_response(
|
||||||
|
self,
|
||||||
|
event: asyncio.Event,
|
||||||
|
state: ReqState,
|
||||||
|
obj: GenerateReqInput,
|
||||||
|
rid: str,
|
||||||
|
request,
|
||||||
|
):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(state.event.wait(), timeout=4)
|
await asyncio.wait_for(state.event.wait(), timeout=4)
|
||||||
@@ -380,7 +395,7 @@ class TokenizerManager:
|
|||||||
req = FlushCacheReq()
|
req = FlushCacheReq()
|
||||||
self.send_to_router.send_pyobj(req)
|
self.send_to_router.send_pyobj(req)
|
||||||
|
|
||||||
def abort_request(self, rid):
|
def abort_request(self, rid: str):
|
||||||
if rid not in self.rid_to_state:
|
if rid not in self.rid_to_state:
|
||||||
return
|
return
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
@@ -426,7 +441,11 @@ class TokenizerManager:
|
|||||||
state.event.set()
|
state.event.set()
|
||||||
|
|
||||||
def convert_logprob_style(
|
def convert_logprob_style(
|
||||||
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
self,
|
||||||
|
ret: dict,
|
||||||
|
return_logprob: bool,
|
||||||
|
top_logprobs_num: int,
|
||||||
|
return_text_in_logprobs: bool,
|
||||||
):
|
):
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
|
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||||
@@ -450,7 +469,7 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
|
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
|
||||||
if not decode_to_text:
|
if not decode_to_text:
|
||||||
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
||||||
|
|
||||||
@@ -461,7 +480,7 @@ class TokenizerManager:
|
|||||||
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
|
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
|
||||||
]
|
]
|
||||||
|
|
||||||
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
|
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
|
||||||
for i, t in enumerate(top_logprobs):
|
for i, t in enumerate(top_logprobs):
|
||||||
if t:
|
if t:
|
||||||
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
|
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
|
||||||
|
|||||||
@@ -118,7 +118,11 @@ def test_decode_json_regex():
|
|||||||
s += "}"
|
s += "}"
|
||||||
|
|
||||||
ret = decode_json.run()
|
ret = decode_json.run()
|
||||||
js_obj = json.loads(ret["json_output"])
|
try:
|
||||||
|
js_obj = json.loads(ret["json_output"])
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
print(ret["json_output"])
|
||||||
|
raise
|
||||||
assert isinstance(js_obj["name"], str)
|
assert isinstance(js_obj["name"], str)
|
||||||
assert isinstance(js_obj["population"], int)
|
assert isinstance(js_obj["population"], int)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user