Support embedding input as a list (#1014)
This commit is contained in:
@@ -153,9 +153,7 @@ class TokenizerManager:
|
|||||||
async for response in self._handle_single_request(obj, request):
|
async for response in self._handle_single_request(obj, request):
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
if isinstance(obj, EmbeddingReqInput):
|
if hasattr(obj, "stream") and obj.stream:
|
||||||
raise NotImplementedError("Please send only one prompt in each request")
|
|
||||||
if obj.stream:
|
|
||||||
raise ValueError("Do not support stream for batch mode.")
|
raise ValueError("Do not support stream for batch mode.")
|
||||||
|
|
||||||
async for response in self._handle_batch_request(obj, request):
|
async for response in self._handle_batch_request(obj, request):
|
||||||
@@ -283,24 +281,29 @@ class TokenizerManager:
|
|||||||
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
|
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
|
||||||
yield input_ids
|
yield input_ids
|
||||||
|
|
||||||
async def _handle_batch_request(self, obj: GenerateReqInput, request):
|
async def _handle_batch_request(
|
||||||
|
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
|
||||||
|
):
|
||||||
batch_size = obj.batch_size
|
batch_size = obj.batch_size
|
||||||
parallel_sample_num = obj.parallel_sample_num
|
if self.is_generation:
|
||||||
|
parallel_sample_num = obj.parallel_sample_num
|
||||||
|
|
||||||
if parallel_sample_num != 1:
|
if parallel_sample_num != 1:
|
||||||
# Send prefill requests to cache the common input
|
# Send prefill requests to cache the common input
|
||||||
parallel_sample_num += 1
|
parallel_sample_num += 1
|
||||||
input_id_result = [] if obj.input_ids is None else None
|
input_id_result = [] if obj.input_ids is None else None
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
async for input_id in self._handle_single_request(
|
async for input_id in self._handle_single_request(
|
||||||
obj, request, index=i, is_cache_for_prefill=True
|
obj, request, index=i, is_cache_for_prefill=True
|
||||||
):
|
):
|
||||||
if input_id_result is not None:
|
if input_id_result is not None:
|
||||||
input_id_result.append(input_id)
|
input_id_result.append(input_id)
|
||||||
if input_id_result is not None and len(input_id_result) > 1:
|
if input_id_result is not None and len(input_id_result) > 1:
|
||||||
obj.input_ids = input_id_result
|
obj.input_ids = input_id_result
|
||||||
elif input_id_result is not None:
|
elif input_id_result is not None:
|
||||||
obj.input_ids = input_id_result[0]
|
obj.input_ids = input_id_result[0]
|
||||||
|
else:
|
||||||
|
parallel_sample_num = 1
|
||||||
|
|
||||||
# First send out all requests
|
# First send out all requests
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
@@ -329,28 +332,38 @@ class TokenizerManager:
|
|||||||
input_text = None
|
input_text = None
|
||||||
input_ids = obj.input_ids[i]
|
input_ids = obj.input_ids[i]
|
||||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
|
||||||
obj.image_data[index]
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
if self.is_generation:
|
||||||
rid,
|
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||||
input_text,
|
obj.image_data[index]
|
||||||
input_ids,
|
)
|
||||||
pixel_values,
|
|
||||||
image_hash,
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
image_size,
|
rid,
|
||||||
sampling_params,
|
input_text,
|
||||||
obj.return_logprob[index],
|
input_ids,
|
||||||
obj.logprob_start_len[index],
|
pixel_values,
|
||||||
obj.top_logprobs_num[index],
|
image_hash,
|
||||||
obj.stream,
|
image_size,
|
||||||
)
|
sampling_params,
|
||||||
|
obj.return_logprob[index],
|
||||||
|
obj.logprob_start_len[index],
|
||||||
|
obj.top_logprobs_num[index],
|
||||||
|
obj.stream,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||||
|
rid,
|
||||||
|
input_text,
|
||||||
|
input_ids,
|
||||||
|
sampling_params,
|
||||||
|
)
|
||||||
self.send_to_router.send_pyobj(tokenized_obj)
|
self.send_to_router.send_pyobj(tokenized_obj)
|
||||||
|
|
||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
state = ReqState([], False, event)
|
state = ReqState([], False, event)
|
||||||
self.rid_to_state[rid] = state
|
self.rid_to_state[rid] = state
|
||||||
|
|
||||||
# Then wait for all responses
|
# Then wait for all responses
|
||||||
output_list = []
|
output_list = []
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
@@ -373,14 +386,17 @@ class TokenizerManager:
|
|||||||
self.abort_request(rid)
|
self.abort_request(rid)
|
||||||
raise ValueError(f"Abort request {rid}")
|
raise ValueError(f"Abort request {rid}")
|
||||||
continue
|
continue
|
||||||
output_list.append(
|
if self.is_generation:
|
||||||
self.convert_logprob_style(
|
output_list.append(
|
||||||
state.out_list[-1],
|
self.convert_logprob_style(
|
||||||
obj.return_logprob[index],
|
state.out_list[-1],
|
||||||
obj.top_logprobs_num[index],
|
obj.return_logprob[index],
|
||||||
obj.return_text_in_logprobs,
|
obj.top_logprobs_num[index],
|
||||||
|
obj.return_text_in_logprobs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
else:
|
||||||
|
output_list.append(state.out_list[-1])
|
||||||
assert state.finished
|
assert state.finished
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
yield output_list
|
yield output_list
|
||||||
|
|||||||
@@ -219,11 +219,9 @@ class SRTRunner:
|
|||||||
output_strs=output_strs, top_input_logprobs=top_input_logprobs
|
output_strs=output_strs, top_input_logprobs=top_input_logprobs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logits = []
|
response = self.runtime.encode(prompts)
|
||||||
for prompt in prompts:
|
response = json.loads(response)
|
||||||
response = self.runtime.encode(prompt)
|
logits = [x["embedding"] for x in response]
|
||||||
response = json.loads(response)
|
|
||||||
logits.append(response["embedding"])
|
|
||||||
return ModelOutput(embed_logits=logits)
|
return ModelOutput(embed_logits=logits)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
|||||||
@@ -38,8 +38,9 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||||
|
|
||||||
if use_list_input:
|
if use_list_input:
|
||||||
prompt_arg = [prompt_input, prompt_input]
|
prompt_arg = [prompt_input] * 2
|
||||||
num_prompts = len(prompt_arg)
|
num_prompts = len(prompt_arg)
|
||||||
|
num_prompt_tokens *= num_prompts
|
||||||
else:
|
else:
|
||||||
prompt_arg = prompt_input
|
prompt_arg = prompt_input
|
||||||
num_prompts = 1
|
num_prompts = 1
|
||||||
@@ -70,7 +71,7 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
def test_embedding(self):
|
def test_embedding(self):
|
||||||
# TODO the fields of encoding_format, dimensions, user are skipped
|
# TODO the fields of encoding_format, dimensions, user are skipped
|
||||||
# TODO support use_list_input
|
# TODO support use_list_input
|
||||||
for use_list_input in [False]:
|
for use_list_input in [False, True]:
|
||||||
for token_input in [False, True]:
|
for token_input in [False, True]:
|
||||||
self.run_embedding(use_list_input, token_input)
|
self.run_embedding(use_list_input, token_input)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user