Organize image inputs (#1531)
This commit is contained in:
@@ -25,7 +25,7 @@ import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention_backend import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
@@ -84,17 +84,10 @@ class InputMetadata:
|
||||
extend_logprob_start_lens_cpu: List[int] = None
|
||||
|
||||
# For multimodal
|
||||
pixel_values: List[torch.Tensor] = None
|
||||
image_sizes: List[List[List[int]]] = None
|
||||
image_offsets: List[List[int]] = None
|
||||
modalities: List[List[str]] = None
|
||||
image_inputs: List[ImageInputs] = None
|
||||
|
||||
def init_multimuldal_info(self, batch: ScheduleBatch):
|
||||
reqs = batch.reqs
|
||||
self.pixel_values = [r.pixel_values for r in reqs]
|
||||
self.image_sizes = [r.image_sizes for r in reqs]
|
||||
self.image_offsets = [r.image_offsets for r in reqs]
|
||||
self.modalities = [r.modalities for r in reqs]
|
||||
self.image_inputs = [r.image_inputs for r in batch.reqs]
|
||||
|
||||
def compute_positions(self, batch: ScheduleBatch):
|
||||
if self.forward_mode.is_decode():
|
||||
|
||||
@@ -498,23 +498,10 @@ class ModelRunner:
|
||||
get_embedding=True,
|
||||
)
|
||||
|
||||
def forward_extend_multi_modal(self, batch: ScheduleBatch):
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
return self.model.forward(
|
||||
batch.input_ids,
|
||||
input_metadata.positions,
|
||||
input_metadata,
|
||||
input_metadata.pixel_values,
|
||||
input_metadata.image_sizes,
|
||||
input_metadata.image_offsets,
|
||||
)
|
||||
|
||||
def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
|
||||
assert batch.forward_mode is not None
|
||||
|
||||
if self.is_multimodal_model and batch.forward_mode.is_extend():
|
||||
return self.forward_extend_multi_modal(batch)
|
||||
elif batch.forward_mode.is_decode():
|
||||
if batch.forward_mode.is_decode():
|
||||
return self.forward_decode(batch)
|
||||
elif batch.forward_mode.is_extend():
|
||||
return self.forward_extend(batch)
|
||||
|
||||
Reference in New Issue
Block a user