model: Minicpmo (#3023)

This commit is contained in:
Mick
2025-03-25 11:08:40 +08:00
committed by GitHub
parent 64129fa632
commit 1e86457c90
40 changed files with 2906 additions and 493 deletions

View File

@@ -43,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -176,7 +176,7 @@ class ForwardBatch:
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
# For multimodal
image_inputs: Optional[List[ImageInputs]] = None
mm_inputs: Optional[List[MultimodalInputs]] = None
# Encoder-decoder
encoder_cached: Optional[List[bool]] = None
@@ -242,7 +242,7 @@ class ForwardBatch:
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc,
image_inputs=batch.image_inputs,
mm_inputs=batch.multimodal_inputs,
encoder_cached=batch.encoder_cached,
encoder_lens=batch.encoder_lens,
encoder_lens_cpu=batch.encoder_lens_cpu,
@@ -332,42 +332,53 @@ class ForwardBatch:
return ret
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
    """Merge all multimodal inputs in the batch into a single object.

    The first non-None entry of ``self.mm_inputs`` is used as the
    accumulator, so it is mutated in place by each ``merge`` call —
    NOTE(review): this also mutates the batch's stored input; presumably
    intentional, confirm against callers. Any numpy arrays left on the
    merged result are converted to torch tensors before returning.

    Returns:
        The merged ``MultimodalInputs``, or ``None`` when the batch
        contains no multimodal input at all.
    """
    if not self.mm_inputs or all(x is None for x in self.mm_inputs):
        return None

    # Filter out None values.
    valid_inputs = [x for x in self.mm_inputs if x is not None]

    # Start with the first valid input and fold the rest into it.
    merged = valid_inputs[0]
    for mm_input in valid_inputs[1:]:
        merged.merge(mm_input)

    # Normalize numpy payloads to torch tensors for downstream model code.
    if isinstance(merged.pixel_values, np.ndarray):
        merged.pixel_values = torch.from_numpy(merged.pixel_values)
    if isinstance(merged.audio_features, np.ndarray):
        merged.audio_features = torch.from_numpy(merged.audio_features)

    return merged
def contains_image_inputs(self) -> bool:
    """Report whether any request in this batch carries image inputs.

    Delegates the per-request check to each entry's own
    ``contains_image_inputs``; ``None`` entries are skipped.
    """
    if self.mm_inputs is None:
        return False
    return any(
        mm_input is not None and mm_input.contains_image_inputs()
        for mm_input in self.mm_inputs
    )
def contains_audio_inputs(self) -> bool:
    """Report whether any request in this batch carries audio inputs."""
    entries = self.mm_inputs
    if entries is None:
        return False
    # Skip missing entries; ask each present one whether it has audio.
    for entry in entries:
        if entry is not None and entry.contains_audio_inputs():
            return True
    return False
def contains_mm_inputs(self) -> bool:
    """Report whether this batch carries any multimodal payload at all."""
    # Preserve the original short-circuit order: audio first, then image.
    if self.contains_audio_inputs():
        return True
    return self.contains_image_inputs()
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
@@ -378,8 +389,8 @@ class ForwardBatch:
for i, _ in enumerate(mrope_positions_list):
mrope_position_delta = (
0
if batch.image_inputs[i] is None
else batch.image_inputs[i].mrope_position_delta
if batch.multimodal_inputs[i] is None
else batch.multimodal_inputs[i].mrope_position_delta
)
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
mrope_position_delta,
@@ -388,13 +399,13 @@ class ForwardBatch:
)
elif self.forward_mode.is_extend():
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
for i, image_inputs in enumerate(batch.image_inputs):
for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
extend_start_loc, extend_seq_len, extend_prefix_len = (
extend_start_loc_cpu[i],
batch.extend_seq_lens[i],
batch.extend_prefix_lens[i],
)
if image_inputs is None:
if multimodal_inputs is None:
# text only
mrope_positions = [
[
@@ -411,20 +422,22 @@ class ForwardBatch:
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
],
image_grid_thw=image_inputs.image_grid_thws,
video_grid_thw=image_inputs.video_grid_thws,
image_token_id=image_inputs.im_token_id,
video_token_id=image_inputs.video_token_id,
image_grid_thw=multimodal_inputs.image_grid_thws,
video_grid_thw=multimodal_inputs.video_grid_thws,
image_token_id=multimodal_inputs.im_token_id,
video_token_id=multimodal_inputs.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
seq_len=len(self.input_ids),
second_per_grid_ts=image_inputs.second_per_grid_ts,
second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
tokens_per_second=hf_config.vision_config.tokens_per_second,
)
)
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
batch.multimodal_inputs[i].mrope_position_delta = (
mrope_position_delta
)
mrope_positions_list[i] = mrope_positions
self.mrope_positions = torch.cat(