[Feat] Add modalities for vision server when handling pixel values for llava (#1346)
This commit is contained in:
committed by
GitHub
parent
8e6bdf851c
commit
662ecd9368
@@ -71,6 +71,7 @@ class Conversation:
|
||||
# Stop criteria (the default one is EOS token)
|
||||
stop_str: Union[str, List[str]] = None
|
||||
image_data: Optional[List[str]] = None
|
||||
modalities: Optional[List[str]] = None
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt for generation."""
|
||||
@@ -379,6 +380,7 @@ def generate_chat_conv(
|
||||
sep2=conv.sep2,
|
||||
stop_str=conv.stop_str,
|
||||
image_data=[],
|
||||
modalities=[],
|
||||
)
|
||||
|
||||
if isinstance(request.messages, str):
|
||||
@@ -408,6 +410,7 @@ def generate_chat_conv(
|
||||
for content in message.content:
|
||||
if content.type == "image_url":
|
||||
num_image_url += 1
|
||||
conv.modalities.append(content.modalities)
|
||||
if num_image_url > 1:
|
||||
image_token = "<image>"
|
||||
else:
|
||||
|
||||
@@ -50,6 +50,8 @@ class GenerateReqInput:
|
||||
return_text_in_logprobs: bool = False
|
||||
# Whether to stream output.
|
||||
stream: bool = False
|
||||
# The modalities of the image data [image, multi-images, video]
|
||||
modalities: Optional[List[str]] = None
|
||||
|
||||
def post_init(self):
|
||||
if (self.text is None and self.input_ids is None) or (
|
||||
@@ -177,6 +179,8 @@ class TokenizedGenerateReqInput:
|
||||
top_logprobs_num: int
|
||||
# Whether to stream output
|
||||
stream: bool
|
||||
# Modalities of the input images
|
||||
modalites: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -130,6 +130,7 @@ class Req:
|
||||
self.image_sizes = None
|
||||
self.image_offsets = None
|
||||
self.pad_value = None
|
||||
self.modalities = None
|
||||
|
||||
# Prefix info
|
||||
self.extend_input_len = 0
|
||||
|
||||
@@ -188,6 +188,7 @@ class TokenizerManager:
|
||||
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
|
||||
obj.image_data if not_use_index else obj.image_data[index]
|
||||
)
|
||||
modalities = obj.modalities
|
||||
return_logprob = (
|
||||
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
||||
)
|
||||
@@ -243,6 +244,7 @@ class TokenizerManager:
|
||||
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
|
||||
obj.image_data[0]
|
||||
)
|
||||
modalities = obj.modalities
|
||||
return_logprob = obj.return_logprob[0]
|
||||
logprob_start_len = obj.logprob_start_len[0]
|
||||
top_logprobs_num = obj.top_logprobs_num[0]
|
||||
@@ -263,6 +265,7 @@ class TokenizerManager:
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
obj.stream,
|
||||
modalities,
|
||||
)
|
||||
else: # is embedding
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
@@ -346,6 +349,7 @@ class TokenizerManager:
|
||||
pixel_values, image_hashes, image_sizes = (
|
||||
await self._get_pixel_values(obj.image_data[index])
|
||||
)
|
||||
modalities = obj.modalities
|
||||
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid,
|
||||
@@ -359,6 +363,7 @@ class TokenizerManager:
|
||||
obj.logprob_start_len[index],
|
||||
obj.top_logprobs_num[index],
|
||||
obj.stream,
|
||||
modalities,
|
||||
)
|
||||
else:
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
|
||||
@@ -358,6 +358,8 @@ class ModelTpServer:
|
||||
req.pixel_values,
|
||||
req.image_sizes,
|
||||
)
|
||||
# Only when pixel values are not None do we have modalities
|
||||
req.modalities = recv_req.modalites
|
||||
req.return_logprob = recv_req.return_logprob
|
||||
req.logprob_start_len = recv_req.logprob_start_len
|
||||
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||
|
||||
@@ -78,6 +78,7 @@ class InputMetadata:
|
||||
pixel_values: List[torch.Tensor] = None
|
||||
image_sizes: List[List[List[int]]] = None
|
||||
image_offsets: List[List[int]] = None
|
||||
modalities: List[List[str]] = None
|
||||
|
||||
# Triton attention backend
|
||||
triton_max_seq_len: int = 0
|
||||
@@ -96,6 +97,7 @@ class InputMetadata:
|
||||
self.pixel_values = [r.pixel_values for r in reqs]
|
||||
self.image_sizes = [r.image_sizes for r in reqs]
|
||||
self.image_offsets = [r.image_offsets for r in reqs]
|
||||
self.modalities = [r.modalities for r in reqs]
|
||||
|
||||
def compute_positions(self, batch: ScheduleBatch):
|
||||
position_ids_offsets = batch.position_ids_offsets
|
||||
|
||||
@@ -138,6 +138,12 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
bs = input_metadata.batch_size
|
||||
# Got List[List[str]]; flatten it to List[str]
|
||||
# The length of the List should be equal to batch size
|
||||
modalities_list = []
|
||||
for modalities in input_metadata.modalities:
|
||||
if modalities is not None:
|
||||
modalities_list.extend(modalities)
|
||||
|
||||
# Embed text inputs
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
@@ -179,7 +185,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
new_image_features = []
|
||||
height = width = self.num_patches_per_side
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if len(image_sizes[image_idx]) == 1:
|
||||
if modalities_list[image_idx] == 1:
|
||||
image_aspect_ratio = (
|
||||
self.config.image_aspect_ratio
|
||||
) # single image
|
||||
@@ -191,6 +197,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
if (
|
||||
image_feature.shape[0] > 1
|
||||
and "anyres" in image_aspect_ratio
|
||||
and modalities_list[image_idx] == "image"
|
||||
):
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
@@ -290,7 +297,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
)
|
||||
image_feature = image_feature.unsqueeze(0)
|
||||
else:
|
||||
if image_feature.shape[0] > 16: # video
|
||||
if modalities_list[image_idx] == "video": # video
|
||||
# 2x2 pooling
|
||||
num_of_frames = image_feature.shape[0]
|
||||
image_feature = image_feature.view(
|
||||
|
||||
@@ -832,6 +832,7 @@ def v1_chat_generate_request(
|
||||
return_logprobs = []
|
||||
logprob_start_lens = []
|
||||
top_logprobs_nums = []
|
||||
modalities_list = []
|
||||
|
||||
# NOTE: with openai API, the prompt's logprobs are always not computed
|
||||
|
||||
@@ -864,10 +865,12 @@ def v1_chat_generate_request(
|
||||
)
|
||||
stop = request.stop
|
||||
image_data = None
|
||||
modalities = []
|
||||
else:
|
||||
conv = generate_chat_conv(request, chat_template_name)
|
||||
prompt = conv.get_prompt()
|
||||
image_data = conv.image_data
|
||||
modalities = conv.modalities
|
||||
stop = conv.stop_str or []
|
||||
if request.stop:
|
||||
if isinstance(request.stop, str):
|
||||
@@ -880,6 +883,7 @@ def v1_chat_generate_request(
|
||||
prompt_ids = request.messages
|
||||
stop = request.stop
|
||||
image_data = None
|
||||
modalities = []
|
||||
input_ids.append(prompt_ids)
|
||||
return_logprobs.append(request.logprobs)
|
||||
logprob_start_lens.append(-1)
|
||||
@@ -901,6 +905,7 @@ def v1_chat_generate_request(
|
||||
}
|
||||
)
|
||||
image_data_list.append(image_data)
|
||||
modalities_list.extend(modalities)
|
||||
if len(all_requests) == 1:
|
||||
input_ids = input_ids[0]
|
||||
if isinstance(input_ids, str):
|
||||
@@ -912,6 +917,7 @@ def v1_chat_generate_request(
|
||||
return_logprobs = return_logprobs[0]
|
||||
logprob_start_lens = logprob_start_lens[0]
|
||||
top_logprobs_nums = top_logprobs_nums[0]
|
||||
modalities_list = modalities_list[:1]
|
||||
else:
|
||||
if isinstance(input_ids[0], str):
|
||||
prompt_kwargs = {"text": input_ids}
|
||||
@@ -928,6 +934,7 @@ def v1_chat_generate_request(
|
||||
stream=all_requests[0].stream,
|
||||
return_text_in_logprobs=True,
|
||||
rid=request_ids,
|
||||
modalities=modalities_list,
|
||||
)
|
||||
if len(all_requests) == 1:
|
||||
return adapted_request, all_requests[0]
|
||||
|
||||
@@ -213,6 +213,7 @@ class ChatCompletionMessageContentImageURL(BaseModel):
|
||||
class ChatCompletionMessageContentImagePart(BaseModel):
|
||||
type: Literal["image_url"]
|
||||
image_url: ChatCompletionMessageContentImageURL
|
||||
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
|
||||
|
||||
|
||||
ChatCompletionMessageContentPart = Union[
|
||||
|
||||
Reference in New Issue
Block a user