Add support for OpenAI API parallel sampling (#640)
This commit is contained in:
@@ -122,125 +122,150 @@ class TokenizerManager:
|
||||
|
||||
obj.post_init()
|
||||
is_single = obj.is_single
|
||||
|
||||
if is_single:
|
||||
rid = obj.rid
|
||||
|
||||
if obj.input_ids is None:
|
||||
input_ids = self.tokenizer.encode(obj.text)
|
||||
else:
|
||||
input_ids = obj.input_ids
|
||||
|
||||
if len(input_ids) >= self.context_len:
|
||||
raise ValueError(
|
||||
f"The input ({len(input_ids)} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens)."
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params)
|
||||
if sampling_params.max_new_tokens != 0:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
|
||||
if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
|
||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
||||
obj.image_data[0]
|
||||
)
|
||||
elif isinstance(obj.image_data, str):
|
||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
||||
obj.image_data
|
||||
)
|
||||
else:
|
||||
pixel_values, image_hash, image_size = None, None, None
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid=rid,
|
||||
input_text=obj.text,
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hash=image_hash,
|
||||
image_size=image_size,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=obj.return_logprob,
|
||||
logprob_start_len=obj.logprob_start_len,
|
||||
top_logprobs_num=obj.top_logprobs_num,
|
||||
stream=obj.stream,
|
||||
)
|
||||
self.send_to_router.send_pyobj(tokenized_obj)
|
||||
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event)
|
||||
self.rid_to_state[rid] = state
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=4)
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
self.abort_request(rid)
|
||||
raise ValueError(f"Abort request {rid}")
|
||||
continue
|
||||
|
||||
out = self.convert_logprob_style(
|
||||
state.out_list[-1],
|
||||
obj.return_logprob,
|
||||
obj.top_logprobs_num,
|
||||
obj.return_text_in_logprobs,
|
||||
)
|
||||
|
||||
if self.server_args.log_requests and state.finished:
|
||||
logger.info(f"in={obj.text}, out={out}")
|
||||
|
||||
state.out_list = []
|
||||
if state.finished:
|
||||
del self.rid_to_state[rid]
|
||||
|
||||
yield out
|
||||
|
||||
break
|
||||
|
||||
event.clear()
|
||||
|
||||
yield out
|
||||
async for response in self._handle_single_request(obj, request):
|
||||
yield response
|
||||
else:
|
||||
if obj.stream:
|
||||
raise ValueError("Do not support stream for batch mode.")
|
||||
|
||||
if obj.input_ids is None:
|
||||
bs = len(obj.text)
|
||||
async for response in self._handle_batch_request(obj, request):
|
||||
yield response
|
||||
|
||||
async def _handle_single_request(self, obj, request, index=None, is_prefill=False):
|
||||
if is_prefill:
|
||||
if isinstance(obj.text, list):
|
||||
input_text = obj.text[index]
|
||||
rid = obj.rid[index]
|
||||
else:
|
||||
bs = len(obj.input_ids)
|
||||
input_text = obj.text
|
||||
rid = obj.rid[0]
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||
sampling_params.max_new_tokens = 0
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
obj.image_data[0]
|
||||
)
|
||||
return_logprob = obj.return_logprob[0]
|
||||
logprob_start_len = obj.logprob_start_len[0]
|
||||
top_logprobs_num = obj.top_logprobs_num[0]
|
||||
else:
|
||||
rid = obj.rid if index is None else obj.rid[index]
|
||||
input_text = obj.text if index is None else obj.text[index]
|
||||
input_ids = (
|
||||
self.tokenizer.encode(input_text)
|
||||
if obj.input_ids is None
|
||||
else obj.input_ids
|
||||
)
|
||||
if index is not None and obj.input_ids:
|
||||
input_ids = obj.input_ids[index]
|
||||
|
||||
for i in range(bs):
|
||||
rid = obj.rid[i]
|
||||
self._validate_input_length(input_ids)
|
||||
sampling_params = self._get_sampling_params(
|
||||
obj.sampling_params if index is None else obj.sampling_params[index]
|
||||
)
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
obj.image_data if index is None else obj.image_data[index]
|
||||
)
|
||||
return_logprob = (
|
||||
obj.return_logprob if index is None else obj.return_logprob[index]
|
||||
)
|
||||
logprob_start_len = (
|
||||
obj.logprob_start_len if index is None else obj.logprob_start_len[index]
|
||||
)
|
||||
top_logprobs_num = (
|
||||
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
|
||||
)
|
||||
|
||||
if obj.input_ids is None:
|
||||
input_text = obj.text[i]
|
||||
input_ids = self.tokenizer.encode(obj.text[i])
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
image_hash,
|
||||
image_size,
|
||||
sampling_params,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
obj.stream,
|
||||
)
|
||||
self.send_to_router.send_pyobj(tokenized_obj)
|
||||
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event)
|
||||
self.rid_to_state[rid] = state
|
||||
if is_prefill == False:
|
||||
async for response in self._wait_for_response(
|
||||
event, state, obj, rid, request
|
||||
):
|
||||
yield response
|
||||
else:
|
||||
await self._wait_for_prefill_response(event, state, obj, request, rid)
|
||||
yield input_ids
|
||||
|
||||
async def _handle_batch_request(self, obj, request):
|
||||
batch_size = obj.batch_size
|
||||
parallel_sample_num = obj.sampling_params[0].get("n", 1)
|
||||
|
||||
if parallel_sample_num != 1:
|
||||
## send prefill requests
|
||||
parallel_sample_num += 1
|
||||
input_id_result = [] if obj.input_ids is None else None
|
||||
for i in range(batch_size):
|
||||
async for input_id in self._handle_single_request(
|
||||
obj, request, index=i, is_prefill=True
|
||||
):
|
||||
if input_id_result is not None:
|
||||
input_id_result.append(input_id)
|
||||
pass
|
||||
if len(input_id_result) > 1 and input_id_result is not None:
|
||||
obj.input_ids = input_id_result
|
||||
elif input_id_result is not None:
|
||||
obj.input_ids = input_id_result[0]
|
||||
# First send out all requests
|
||||
for i in range(batch_size):
|
||||
for j in range(parallel_sample_num):
|
||||
if j == 0 and parallel_sample_num != 1:
|
||||
continue
|
||||
index = i * parallel_sample_num + j
|
||||
if parallel_sample_num != 1:
|
||||
# When using parallel sampling we should account for the prefill stage, so the index is: j + i * (parallel_sample_num - 1) + batch_size - 1
|
||||
index += batch_size - 1 - i
|
||||
rid = obj.rid[index]
|
||||
if parallel_sample_num == 1:
|
||||
## select operation
|
||||
if obj.input_ids is None:
|
||||
input_text = obj.text[i]
|
||||
input_ids = self.tokenizer.encode(obj.text[i])
|
||||
else:
|
||||
input_text = None
|
||||
input_ids = obj.input_ids[i]
|
||||
else:
|
||||
input_text = None
|
||||
input_ids = obj.input_ids[i]
|
||||
if batch_size == 1:
|
||||
input_text = obj.text
|
||||
input_ids = obj.input_ids
|
||||
else:
|
||||
input_text = obj.text[i]
|
||||
input_ids = obj.input_ids[i]
|
||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
obj.image_data[index]
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params[i])
|
||||
if sampling_params.max_new_tokens != 0:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
if obj.image_data[i] is None:
|
||||
pixel_values, image_hash, image_size = None, None, None
|
||||
else:
|
||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
||||
obj.image_data[i]
|
||||
)
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid=rid,
|
||||
input_text=input_text,
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hash=image_hash,
|
||||
image_size=image_size,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=obj.return_logprob[i],
|
||||
logprob_start_len=obj.logprob_start_len[i],
|
||||
top_logprobs_num=obj.top_logprobs_num[i],
|
||||
stream=obj.stream,
|
||||
rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
image_hash,
|
||||
image_size,
|
||||
sampling_params,
|
||||
obj.return_logprob[index],
|
||||
obj.logprob_start_len[index],
|
||||
obj.top_logprobs_num[index],
|
||||
obj.stream,
|
||||
)
|
||||
self.send_to_router.send_pyobj(tokenized_obj)
|
||||
|
||||
@@ -248,9 +273,16 @@ class TokenizerManager:
|
||||
state = ReqState([], False, event)
|
||||
self.rid_to_state[rid] = state
|
||||
|
||||
output_list = []
|
||||
for i in range(bs):
|
||||
rid = obj.rid[i]
|
||||
# Then wait for all responses
|
||||
output_list = []
|
||||
for i in range(batch_size):
|
||||
for j in range(parallel_sample_num):
|
||||
if j == 0 and parallel_sample_num != 1:
|
||||
continue
|
||||
index = i * parallel_sample_num + j
|
||||
if parallel_sample_num != 1:
|
||||
index += batch_size - 1 - i
|
||||
rid = obj.rid[index]
|
||||
state = self.rid_to_state[rid]
|
||||
|
||||
while True:
|
||||
@@ -263,19 +295,86 @@ class TokenizerManager:
|
||||
self.abort_request(rid)
|
||||
raise ValueError(f"Abort request {rid}")
|
||||
continue
|
||||
|
||||
output_list.append(
|
||||
self.convert_logprob_style(
|
||||
state.out_list[-1],
|
||||
obj.return_logprob[i],
|
||||
obj.top_logprobs_num[i],
|
||||
obj.return_logprob[index],
|
||||
obj.top_logprobs_num[index],
|
||||
obj.return_text_in_logprobs,
|
||||
)
|
||||
)
|
||||
assert state.finished
|
||||
del self.rid_to_state[rid]
|
||||
|
||||
yield output_list
|
||||
yield output_list
|
||||
|
||||
def _validate_input_length(self, input_ids):
|
||||
if len(input_ids) >= self.context_len:
|
||||
raise ValueError(
|
||||
f"The input ({len(input_ids)} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens)."
|
||||
)
|
||||
|
||||
# Construct a SamplingParams from the keyword mapping `sampling_params_data`
# and return it.
#   max_new_tokens: when not None, overwrites the constructed object's
#       max_new_tokens before the zero check below.
# normalize(self.tokenizer) is only called when max_new_tokens is non-zero.
# NOTE(review): indentation is lost in this rendering, so whether verify()
# is also guarded by that check cannot be determined here — confirm against
# the repository before relying on it.
def _get_sampling_params(self, sampling_params_data, max_new_tokens=None):
|
||||
sampling_params = SamplingParams(**sampling_params_data)
|
||||
if max_new_tokens is not None:
|
||||
sampling_params.max_new_tokens = max_new_tokens
|
||||
if sampling_params.max_new_tokens != 0:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
return sampling_params
|
||||
|
||||
async def _get_pixel_values(self, image_data):
|
||||
if isinstance(image_data, list) and len(image_data) > 0:
|
||||
return await self.get_pixel_values(image_data[0])
|
||||
elif isinstance(image_data, str):
|
||||
return await self.get_pixel_values(image_data)
|
||||
else:
|
||||
return None, None, None
|
||||
|
||||
async def _wait_for_response(self, event, state, obj, rid, request):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=4)
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
self.abort_request(rid)
|
||||
raise ValueError(f"Abort request {rid}")
|
||||
continue
|
||||
|
||||
out = self.convert_logprob_style(
|
||||
state.out_list[-1],
|
||||
obj.return_logprob,
|
||||
obj.top_logprobs_num,
|
||||
obj.return_text_in_logprobs,
|
||||
)
|
||||
|
||||
if self.server_args.log_requests and state.finished:
|
||||
logger.info(f"in={obj.text}, out={out}")
|
||||
|
||||
state.out_list = []
|
||||
if state.finished:
|
||||
del self.rid_to_state[rid]
|
||||
yield out
|
||||
break
|
||||
|
||||
event.clear()
|
||||
yield out
|
||||
|
||||
async def _wait_for_prefill_response(self, event, state, obj, request, rid):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(state.event.wait(), timeout=4)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
for rid in obj.rid:
|
||||
self.abort_request(rid)
|
||||
raise ValueError(f"Abort request {rid}")
|
||||
continue
|
||||
|
||||
assert state.finished
|
||||
del self.rid_to_state[rid]
|
||||
|
||||
def flush_cache(self):
|
||||
req = FlushCacheReq()
|
||||
|
||||
Reference in New Issue
Block a user