Fix hash collision for multi modal models (#2256)

This commit is contained in:
Lianmin Zheng
2024-11-29 03:15:58 -08:00
committed by GitHub
parent fe97a2d40f
commit f50a6cf443
6 changed files with 42 additions and 39 deletions

View File

@@ -526,8 +526,9 @@ class Scheduler:
self,
recv_req: TokenizedGenerateReqInput,
):
# Create a new request
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
# Create a new request
if recv_req.input_embeds is not None:
# Generate fake input_ids based on the length of input_embeds
seq_length = len(recv_req.input_embeds)
@@ -558,20 +559,20 @@ class Scheduler:
self.waiting_queue.append(req)
return
# Image inputs
# Handle image inputs
if recv_req.image_inputs is not None:
image_inputs = ImageInputs.from_dict(
recv_req.image_inputs, self.model_config.vocab_size
)
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
# Expand a single image token into multiple dummy tokens for receiving image embeddings
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids, image_inputs
)
req.extend_image_inputs(image_inputs, self.model_config.vocab_size)
req.extend_image_inputs(image_inputs)
if len(req.origin_input_ids) > self.max_req_input_len:
req.finished_reason = FINISH_ABORT(
"Image request length is longer than the KV cache pool size or "
"the max context length aborting because you cannot truncate the image embeds"
"the max context length. "
"Abort this request because you cannot truncate the image embeds"
)
req.image_inputs = None
req.origin_input_ids = [0]
@@ -579,6 +580,7 @@ class Scheduler:
self.waiting_queue.append(req)
return
# Copy more attributes
req.return_logprob = recv_req.return_logprob
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream