Organize image inputs (#1531)
@@ -172,12 +172,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The pixel values for input images
-    pixel_values: List[float]
-    # The hash values of input images
-    image_hashes: List[int]
-    # The image sizes
-    image_sizes: List[List[int]]
+    # The image input
+    image_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -188,8 +184,6 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
-    # Modalities of the input images
-    modalites: Optional[List[str]] = None

     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model

@@ -102,6 +102,39 @@ class FINISH_ABORT(BaseFinishReason):
         }


+@dataclass
+class ImageInputs:
+    pixel_values: torch.Tensor
+    image_hash: int
+    image_sizes: Optional[list] = None
+    image_offsets: Optional[list] = None
+    pad_values: Optional[list] = None
+    modalities: Optional[list] = None
+
+    image_embeds: Optional[List[torch.Tensor]] = None
+    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def from_dict(obj, vocab_size):
+        # Use image hash as fake token_ids, which is then used for prefix matching
+        ret = ImageInputs(
+            pixel_values=obj["pixel_values"],
+            image_hash=hash(tuple(obj["image_hashes"])),
+        )
+        image_hash = ret.image_hash
+        ret.pad_values = [
+            (image_hash) % vocab_size,
+            (image_hash >> 16) % vocab_size,
+            (image_hash >> 32) % vocab_size,
+            (image_hash >> 64) % vocab_size,
+        ]
+        ret.image_sizes = obj["image_sizes"]
+        # Only when pixel values is not None we have modalities
+        ret.modalities = obj["modalities"]
+        return ret
+
+
 class Req:
     """Store all inforamtion of a request."""

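The from_dict path above is what lets images participate in radix-cache prefix matching. A minimal, self-contained sketch of that mapping (the helper name and vocab_size 32000 are placeholders, not part of the diff):

def pad_values_from_hashes(image_hashes, vocab_size=32000):
    # Mirror ImageInputs.from_dict: fold the per-image hashes into one value,
    # then spread it over four in-vocabulary "fake" token ids.
    combined = hash(tuple(image_hashes))
    return [
        combined % vocab_size,
        (combined >> 16) % vocab_size,
        (combined >> 32) % vocab_size,
        (combined >> 64) % vocab_size,
    ]

# Same images -> same pad ids -> the padded prefixes of two requests match,
# so the radix cache can reuse the KV entries covering the image region.
assert pad_values_from_hashes([111, 222]) == pad_values_from_hashes([111, 222])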
@@ -147,11 +180,7 @@ class Req:
         self.completion_tokens_wo_jump_forward = 0

         # For vision inputs
-        self.pixel_values = None
-        self.image_sizes = None
-        self.image_offsets = None
-        self.pad_value = None
-        self.modalities = None
+        self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
         self.prefix_indices = []
@@ -654,15 +683,9 @@ class ScheduleBatch:
                 self.tree_cache.cache_finished_req(req, cur_all_ids)

                 # re-applying image padding
-                if req.pixel_values is not None:
-                    (
-                        req.origin_input_ids,
-                        req.image_offsets,
-                    ) = model_runner.model.pad_input_ids(
-                        req.origin_input_ids_unpadded,
-                        req.pad_value,
-                        req.pixel_values,
-                        req.image_sizes,
+                if req.image_inputs is not None:
+                    req.origin_input_ids = model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded, req.image_inputs
                     )

                 jump_forward_reqs.append(req)

@@ -194,10 +194,9 @@ class TokenizerManager:
             )

             if self.is_generation:
-                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                    obj.image_data if not_use_index else obj.image_data[index]
+                image_inputs = await self._get_image_inputs(
+                    obj, obj.image_data if not_use_index else obj.image_data[index]
                 )
-                modalities = obj.modalities
                 return_logprob = (
                     obj.return_logprob if not_use_index else obj.return_logprob[index]
                 )
@@ -248,10 +247,7 @@ class TokenizerManager:

             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
-            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                obj.image_data[0]
-            )
-            modalities = obj.modalities
+            image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
@@ -262,15 +258,12 @@ class TokenizerManager:
                 rid,
                 input_text,
                 input_ids,
-                pixel_values,
-                image_hashes,
-                image_sizes,
+                image_inputs,
                 sampling_params,
                 return_logprob,
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                modalities,
                 (
                     obj.lora_path[index]
                     if isinstance(obj.lora_path, list)
@@ -369,24 +362,20 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params[index])

             if self.is_generation:
-                pixel_values, image_hashes, image_sizes = (
-                    await self._get_pixel_values(obj.image_data[index])
+                image_inputs = await self._get_image_inputs(
+                    obj, obj.image_data[index]
                 )
-                modalities = obj.modalities

                 tokenized_obj = TokenizedGenerateReqInput(
                     rid,
                     input_text,
                     input_ids,
-                    pixel_values,
-                    image_hashes,
-                    image_sizes,
+                    image_inputs,
                     sampling_params,
                     obj.return_logprob[index],
                     obj.logprob_start_len[index],
                     obj.top_logprobs_num[index],
                     obj.stream,
-                    modalities,
                     (
                         obj.lora_path[index]
                         if isinstance(obj.lora_path, list)
@@ -697,10 +686,11 @@ class TokenizerManager:
             )
         return top_logprobs

-    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+    async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
         if not image_data:
-            return None, None, None
+            return None

+        # TODO: move this into a processor for each vision architecture
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
@@ -741,7 +731,12 @@ class TokenizerManager:
         else:
             raise ValueError(f"Invalid image data: {image_data}")

-        return pixel_values, image_hashes, image_sizes
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": obj.modalities,
+        }

     async def _process_single_image(
         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str

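The dict above is the wire format between the tokenizer process and the scheduler: TokenizedGenerateReqInput carries it as image_inputs, and ImageInputs.from_dict rehydrates it on the other side. A sketch of the contract with dummy values (the tensor and hash values here are stand-ins, not real processor output):

raw_image_inputs = {
    "pixel_values": [[0.0]],        # stand-in for the processed image tensors
    "image_hashes": [1234567890],   # stand-in for the per-image content hashes
    "image_sizes": [[336, 336]],
    "modalities": ["image"],
}
# These are exactly the keys ImageInputs.from_dict reads on the scheduler side:
assert set(raw_image_inputs) == {
    "pixel_values", "image_hashes", "image_sizes", "modalities"
}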
@@ -49,6 +49,7 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
+    ImageInputs,
     Req,
     ScheduleBatch,
 )
@@ -340,29 +341,16 @@ class ModelTpServer:
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
         req.sampling_params = recv_req.sampling_params
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            # Use image hash as fake token_ids, which is then used
-            # for prefix matching
-            image_hash = hash(tuple(recv_req.image_hashes))
-            req.pad_value = [
-                (image_hash) % self.model_config.vocab_size,
-                (image_hash >> 16) % self.model_config.vocab_size,
-                (image_hash >> 32) % self.model_config.vocab_size,
-                (image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_sizes = recv_req.image_sizes
-            (
-                req.origin_input_ids,
-                req.image_offsets,
-            ) = self.model_runner.model.pad_input_ids(
-                req.origin_input_ids_unpadded,
-                req.pad_value,
-                req.pixel_values,
-                req.image_sizes,
+
+        # Image inputs
+        if recv_req.image_inputs is not None:
+            req.image_inputs = ImageInputs.from_dict(
+                recv_req.image_inputs, self.model_config.vocab_size
             )
-            # Only when pixel values is not None we have modalities
-            req.modalities = recv_req.modalites
+            req.origin_input_ids = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded, req.image_inputs
+            )

         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream

@@ -25,7 +25,7 @@ import torch

 if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -84,17 +84,10 @@ class InputMetadata:
     extend_logprob_start_lens_cpu: List[int] = None

     # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[List[int]]] = None
-    image_offsets: List[List[int]] = None
-    modalities: List[List[str]] = None
+    image_inputs: List[ImageInputs] = None

     def init_multimuldal_info(self, batch: ScheduleBatch):
-        reqs = batch.reqs
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_sizes for r in reqs]
-        self.image_offsets = [r.image_offsets for r in reqs]
-        self.modalities = [r.modalities for r in reqs]
+        self.image_inputs = [r.image_inputs for r in batch.reqs]

     def compute_positions(self, batch: ScheduleBatch):
         if self.forward_mode.is_decode():

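Four parallel per-request lists collapse into one List[ImageInputs], where text-only requests contribute None; that is why the model code further below guards with "if im and ...". A tiny sketch of the resulting shape (the _Req stand-in is hypothetical):

class _Req:  # stand-in for sglang's Req, with only the field used here
    def __init__(self, image_inputs=None):
        self.image_inputs = image_inputs

batch_reqs = [_Req(), _Req(image_inputs=object())]   # text-only, multimodal
image_inputs = [r.image_inputs for r in batch_reqs]  # as in init_multimuldal_info
assert [im is not None for im in image_inputs] == [False, True]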
@@ -498,23 +498,10 @@ class ModelRunner:
             get_embedding=True,
         )

-    def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        return self.model.forward(
-            batch.input_ids,
-            input_metadata.positions,
-            input_metadata,
-            input_metadata.pixel_values,
-            input_metadata.image_sizes,
-            input_metadata.image_offsets,
-        )
-
     def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
         assert batch.forward_mode is not None

-        if self.is_multimodal_model and batch.forward_mode.is_extend():
-            return self.forward_extend_multi_modal(batch)
-        elif batch.forward_mode.is_decode():
+        if batch.forward_mode.is_decode():
             return self.forward_decode(batch)
         elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)

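With image data riding inside InputMetadata.image_inputs, the dedicated multimodal extend path disappears and dispatch reduces to the forward mode alone. A runnable sketch of the simplified control flow (the enum and return strings are stand-ins mirroring the diff, not the real API):

from enum import Enum, auto

class ForwardMode(Enum):  # stand-in for sglang's ForwardMode
    DECODE = auto()
    EXTEND = auto()

def forward(mode: ForwardMode) -> str:
    # Text-only and multimodal batches now share the same extend path;
    # models pull image data out of InputMetadata themselves.
    if mode == ForwardMode.DECODE:
        return "forward_decode"
    elif mode == ForwardMode.EXTEND:
        return "forward_extend"
    raise ValueError(mode)

assert forward(ForwardMode.EXTEND) == "forward_extend"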
@@ -35,25 +35,22 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM


 class LlavaBaseForCausalLM(nn.Module):
-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+
         # hardcode for spatial_unpad + anyres
         image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
@@ -92,8 +89,8 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_w = int(new_w // times)
             new_image_feature_len += new_h * (new_w + 1)

-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        pad_ids = pad_values * (
+            (new_image_feature_len + len(pad_values)) // len(pad_values)
         )
         # print("calculated new_image_feature_len: ", new_image_feature_len)
         try:
@@ -107,7 +104,9 @@ class LlavaBaseForCausalLM(nn.Module):
                     + input_ids[offset + 1 :]
                 )
             offset_list.append(offset)
-        return input_ids, offset_list
+
+        image_inputs.image_offsets = offset_list
+        return input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -132,32 +131,39 @@ class LlavaBaseForCausalLM(nn.Module):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
+        image_inputs = input_metadata.image_inputs
+
         if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
-            for modalities in input_metadata.modalities:
-                if modalities is not None:
-                    modalities_list.extend(modalities)
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.modalities is not None:
+                    modalities_list.extend(im.modalities)
+                if im and im.image_offsets is not None:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)

             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

             # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_sizes = [
+                    image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                 ########## Encode Image ########

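The need_vision test above decides, per request, whether the vision tower can be skipped: if a cached prefix already covers every image token, the extend region starts after the last image offset. A self-contained numeric sketch (the values are hypothetical):

import numpy as np

start_positions = np.array([0, 600, 0])      # where each request's extend begins
max_image_offset = np.array([575, 575, -1])  # last image-token offset; -1 = no image
need_vision = start_positions <= max_image_offset
assert need_vision.tolist() == [True, False, False]
# Request 0 must encode its image; request 1 resumes past the image tokens
# thanks to the cache; request 2 is text-only.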
@@ -26,7 +26,8 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM


@@ -54,17 +55,12 @@ class LlavaVidForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )

-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        pad_values = image_inputs.pad_values
         new_image_feature_len = self.image_feature_len

-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        pad_ids = pad_values * (
+            (new_image_feature_len + len(pad_values)) // len(pad_values)
         )
         offset = input_ids.index(self.config.image_token_index)
         # old_len + pad_len - 1, because we need to remove image_token_id
@@ -73,7 +69,8 @@ class LlavaVidForCausalLM(nn.Module):
             + pad_ids[:new_image_feature_len]
             + input_ids[offset + 1 :]
         )
-        return new_input_ids, [offset]
+        image_inputs.image_offsets = [offset]
+        return new_input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
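Both LLaVA variants now return only the padded ids and record the offsets on the ImageInputs object as a side effect. A runnable sketch of the new contract, mirroring the LlavaVid version above (the dataclass, token id, and feature length are placeholders):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class _ImageInputs:  # minimal stand-in for sglang's ImageInputs
    pad_values: List[int]
    image_offsets: Optional[List[int]] = None

IMAGE_TOKEN = -200  # hypothetical image_token_index

def pad_input_ids(input_ids, image_inputs, feature_len=4):
    # Repeat the pad values until they cover the image feature region.
    pad_ids = image_inputs.pad_values * (
        (feature_len + len(image_inputs.pad_values)) // len(image_inputs.pad_values)
    )
    offset = input_ids.index(IMAGE_TOKEN)
    new_ids = input_ids[:offset] + pad_ids[:feature_len] + input_ids[offset + 1 :]
    image_inputs.image_offsets = [offset]  # side effect replaces the old tuple return
    return new_ids

ii = _ImageInputs(pad_values=[7, 8, 9, 10])
assert pad_input_ids([1, 2, IMAGE_TOKEN, 3], ii) == [1, 2, 7, 8, 9, 10, 3]
assert ii.image_offsets == [2]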
@@ -112,10 +109,8 @@ class LlavaVidForCausalLM(nn.Module):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
+        image_inputs = input_metadata.image_inputs
         if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
@@ -123,14 +118,22 @@ class LlavaVidForCausalLM(nn.Module):
             input_embeds = self.language_model.model.embed_tokens(input_ids)

             # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.image_offsets:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)
             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                 ########## Encode Image ########