Organize image inputs (#1531)

This commit is contained in:
Liangsheng Yin
2024-09-28 23:28:55 -07:00
committed by GitHub
parent e165a9fc1b
commit fd9ad817ec
8 changed files with 121 additions and 132 deletions

View File

@@ -25,7 +25,7 @@ import torch
if TYPE_CHECKING:
from sglang.srt.layers.attention_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -84,17 +84,10 @@ class InputMetadata:
extend_logprob_start_lens_cpu: List[int] = None
# For multimodal
pixel_values: List[torch.Tensor] = None
image_sizes: List[List[List[int]]] = None
image_offsets: List[List[int]] = None
modalities: List[List[str]] = None
image_inputs: List[ImageInputs] = None
def init_multimuldal_info(self, batch: ScheduleBatch):
reqs = batch.reqs
self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_sizes for r in reqs]
self.image_offsets = [r.image_offsets for r in reqs]
self.modalities = [r.modalities for r in reqs]
self.image_inputs = [r.image_inputs for r in batch.reqs]
def compute_positions(self, batch: ScheduleBatch):
if self.forward_mode.is_decode():

View File

@@ -498,23 +498,10 @@ class ModelRunner:
get_embedding=True,
)
def forward_extend_multi_modal(self, batch: ScheduleBatch):
input_metadata = InputMetadata.from_schedule_batch(self, batch)
return self.model.forward(
batch.input_ids,
input_metadata.positions,
input_metadata,
input_metadata.pixel_values,
input_metadata.image_sizes,
input_metadata.image_offsets,
)
def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
assert batch.forward_mode is not None
if self.is_multimodal_model and batch.forward_mode.is_extend():
return self.forward_extend_multi_modal(batch)
elif batch.forward_mode.is_decode():
if batch.forward_mode.is_decode():
return self.forward_decode(batch)
elif batch.forward_mode.is_extend():
return self.forward_extend(batch)