Support embedding input as a list (#1014)
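Batch embedding requests were previously rejected in TokenizerManager with NotImplementedError ("Please send only one prompt in each request"). With this change, EmbeddingReqInput batches go through _handle_batch_request, which builds one TokenizedEmbeddingReqInput per item, and the SRTRunner test utility passes the whole prompt list to runtime.encode in a single call.

A minimal usage sketch of the list form (the model path and prompts are illustrative; the response shape, a JSON list with one "embedding" entry per input, follows the SRTRunner hunk below):

import json

import sglang as sgl

# Illustrative embedding model; any embedding model supported by SGLang works here.
runtime = sgl.Runtime(model_path="intfloat/e5-mistral-7b-instruct")

prompts = [
    "The capital of France is Paris.",
    "Embedding inputs can now be sent as a list.",
]

# encode() previously took a single prompt; it now accepts the whole list.
response = json.loads(runtime.encode(prompts))
embeddings = [item["embedding"] for item in response]
print(len(embeddings), "embeddings, dim", len(embeddings[0]))

runtime.shutdown()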
@@ -153,9 +153,7 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if isinstance(obj, EmbeddingReqInput):
-                raise NotImplementedError("Please send only one prompt in each request")
-            if obj.stream:
+            if hasattr(obj, "stream") and obj.stream:
                 raise ValueError("Do not support stream for batch mode.")

             async for response in self._handle_batch_request(obj, request):
@@ -283,24 +281,29 @@ class TokenizerManager:
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids

-    async def _handle_batch_request(self, obj: GenerateReqInput, request):
+    async def _handle_batch_request(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+    ):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.parallel_sample_num
+        if self.is_generation:
+            parallel_sample_num = obj.parallel_sample_num

-        if parallel_sample_num != 1:
-            # Send prefill requests to cache the common input
-            parallel_sample_num += 1
-            input_id_result = [] if obj.input_ids is None else None
-            for i in range(batch_size):
-                async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_cache_for_prefill=True
-                ):
-                    if input_id_result is not None:
-                        input_id_result.append(input_id)
-            if input_id_result is not None and len(input_id_result) > 1:
-                obj.input_ids = input_id_result
-            elif input_id_result is not None:
-                obj.input_ids = input_id_result[0]
+            if parallel_sample_num != 1:
+                # Send prefill requests to cache the common input
+                parallel_sample_num += 1
+                input_id_result = [] if obj.input_ids is None else None
+                for i in range(batch_size):
+                    async for input_id in self._handle_single_request(
+                        obj, request, index=i, is_cache_for_prefill=True
+                    ):
+                        if input_id_result is not None:
+                            input_id_result.append(input_id)
+                if input_id_result is not None and len(input_id_result) > 1:
+                    obj.input_ids = input_id_result
+                elif input_id_result is not None:
+                    obj.input_ids = input_id_result[0]
+        else:
+            parallel_sample_num = 1

         # First send out all requests
         for i in range(batch_size):
@@ -329,28 +332,38 @@ class TokenizerManager:
                 input_text = None
                 input_ids = obj.input_ids[i]
             sampling_params = self._get_sampling_params(obj.sampling_params[index])
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[index]
-            )
-
-            tokenized_obj = TokenizedGenerateReqInput(
-                rid,
-                input_text,
-                input_ids,
-                pixel_values,
-                image_hash,
-                image_size,
-                sampling_params,
-                obj.return_logprob[index],
-                obj.logprob_start_len[index],
-                obj.top_logprobs_num[index],
-                obj.stream,
-            )
+            if self.is_generation:
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                    obj.image_data[index]
+                )
+
+                tokenized_obj = TokenizedGenerateReqInput(
+                    rid,
+                    input_text,
+                    input_ids,
+                    pixel_values,
+                    image_hash,
+                    image_size,
+                    sampling_params,
+                    obj.return_logprob[index],
+                    obj.logprob_start_len[index],
+                    obj.top_logprobs_num[index],
+                    obj.stream,
+                )
+            else:
+                tokenized_obj = TokenizedEmbeddingReqInput(
+                    rid,
+                    input_text,
+                    input_ids,
+                    sampling_params,
+                )
             self.send_to_router.send_pyobj(tokenized_obj)

             event = asyncio.Event()
             state = ReqState([], False, event)
             self.rid_to_state[rid] = state

         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
@@ -373,14 +386,17 @@ class TokenizerManager:
                             self.abort_request(rid)
                         raise ValueError(f"Abort request {rid}")
                     continue
-            output_list.append(
-                self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob[index],
-                    obj.top_logprobs_num[index],
-                    obj.return_text_in_logprobs,
+            if self.is_generation:
+                output_list.append(
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[index],
+                        obj.top_logprobs_num[index],
+                        obj.return_text_in_logprobs,
+                    )
                 )
-            )
+            else:
+                output_list.append(state.out_list[-1])
             assert state.finished
             del self.rid_to_state[rid]
         yield output_list

@@ -219,11 +219,9 @@ class SRTRunner:
                 output_strs=output_strs, top_input_logprobs=top_input_logprobs
             )
         else:
-            logits = []
-            for prompt in prompts:
-                response = self.runtime.encode(prompt)
-                response = json.loads(response)
-                logits.append(response["embedding"])
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
             return ModelOutput(embed_logits=logits)

     def __enter__(self):
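The SRTRunner hunk above drops the per-prompt encode loop in favor of one batched call. A small sketch of that contract (the helper name is hypothetical; it assumes, as in the hunk, that encode() returns a JSON string parsing to a list with one {"embedding": [...]} entry per prompt):

import json

def embed_batch(runtime, prompts):
    # One batched encode() call instead of looping prompt by prompt.
    response = json.loads(runtime.encode(prompts))
    assert len(response) == len(prompts)
    return [item["embedding"] for item in response]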