[Minor] clean up multimodal processor and tokenizer manager (#7624)

This commit is contained in:
Lianmin Zheng
2025-06-29 02:50:14 -07:00
committed by GitHub
parent 7c0db3a6c5
commit 071a1f51ae
9 changed files with 147 additions and 165 deletions

View File

@@ -22,17 +22,16 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams
# handle serialization of Image for pydantic
# Handle serialization of Image for pydantic
if TYPE_CHECKING:
from PIL.Image import Image
else:
Image = Any
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@dataclass
class SessionParams:
@@ -182,6 +181,7 @@ class GenerateReqInput:
# Determine parallel sample count
if self.sampling_params is None:
self.parallel_sample_num = 1
return
elif isinstance(self.sampling_params, dict):
self.parallel_sample_num = self.sampling_params.get("n", 1)
else: # isinstance(self.sampling_params, list):

View File

@@ -25,7 +25,6 @@ def get_dummy_processor():
return DummyMultimodalProcessor()
@lru_cache()
def import_processors():
package_name = "sglang.srt.multimodal.processors"
package = importlib.import_module(package_name)

View File

@@ -180,46 +180,48 @@ class Modality(Enum):
@dataclasses.dataclass
class MultimodalDataItem:
"""
A single multimodal data, from a single image/video/audio or others
A single multimodal data, from a single image/video/audio or others.
We put the common fields first and the model-specific fields last.
"""
modality: Modality
hash: int = None
pad_value: int = None
aspect_ratio_id: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
image_sizes: Tuple[int, int] = None
image_offsets: Optional[list] = None
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.ndarray]]
pixel_values: Union[torch.Tensor, np.ndarray] = None
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
image_emb_mask: Optional[torch.Tensor] = None
image_spatial_crop: Optional[torch.Tensor] = None
second_per_grid_ts: Optional[List[torch.Tensor]] = None
# [num_images, (n, w, h)]
tgt_size: Tuple[int, int] = None
# kimi-vl related
image_grid_hws: Optional[List[torch.Tensor]] = None
audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
# gemma3n related
# For qwen-vl
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
second_per_grid_ts: Optional[List[torch.Tensor]] = None
# For deepseek-vl
image_emb_mask: Optional[torch.Tensor] = None
image_spatial_crop: Optional[torch.Tensor] = None
# For minicpmv
# [num_images, (n, w, h)]
tgt_size: Tuple[int, int] = None
# For mllama
aspect_ratio_id: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# For kimi-vl
image_grid_hws: Optional[List[torch.Tensor]] = None
# For gemma3n
input_features: Optional[torch.Tensor] = None
input_features_mask: Optional[torch.Tensor] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
@staticmethod
def is_empty_list(l):
if l is None:
@@ -339,10 +341,6 @@ class MultimodalInputs:
image_pad_len: Optional[list] = None
num_image_tokens: Optional[int] = None
# QWen2-VL related
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[torch.Tensor] = None
# image
im_token_id: Optional[int] = None
im_start_id: Optional[int] = None
@@ -358,6 +356,10 @@ class MultimodalInputs:
audio_start_id: Optional[int] = None
audio_end_id: Optional[int] = None
# QWen2-VL related
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[torch.Tensor] = None
@staticmethod
def from_dict(obj: dict):
ret = MultimodalInputs(

View File

@@ -150,7 +150,9 @@ class ReqState:
# For streaming output
last_output_offset: int = 0
# For incremental state update.
# TODO(lianmin): do not initialize some lists if not needed.
text: str = ""
output_ids: List[int] = dataclasses.field(default_factory=list)
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
@@ -199,7 +201,6 @@ class TokenizerManager:
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
self.model_config = ModelConfig.from_server_args(server_args)
self.is_generation = self.model_config.is_generation
self.is_image_gen = self.model_config.is_image_gen
self.context_len = self.model_config.context_len
@@ -251,19 +252,36 @@ class TokenizerManager:
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
self.log_request_metadata = self.get_log_request_metadata()
self.asyncio_tasks = set()
self.session_futures = {} # session_id -> asyncio event
self.max_req_input_len = None
# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None
)
self.asyncio_tasks = set()
# For session info
self.session_futures = {} # session_id -> asyncio event
# For pd disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tokenizer manager
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)
# Set after scheduler is initialized
self.max_req_input_len = None
# For load balancing
self.current_load = 0
self.current_load_lock = asyncio.Lock()
# Metrics
if self.enable_metrics:
@@ -393,34 +411,14 @@ class TokenizerManager:
]
)
# For pd disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on prefill tokenizer manager
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)
self.current_load = 0
self.current_load_lock = asyncio.Lock()
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
@@ -428,22 +426,6 @@ class TokenizerManager:
"Please add `--is-embedding` when launching the server or try another model."
)
obj.normalize_batch_and_arguments()
if isinstance(obj, GenerateReqInput):
return_hidden_states = obj.return_hidden_states
has_return_hidden_states = return_hidden_states == True or (
isinstance(return_hidden_states, list) and any(return_hidden_states)
)
if (
not self.server_args.enable_return_hidden_states
and has_return_hidden_states
):
raise ValueError(
"return_hidden_states=True requires the server to be started "
"with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
)
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
@@ -451,8 +433,7 @@ class TokenizerManager:
)
async with self.model_update_lock.reader_lock:
is_single = obj.is_single
if is_single:
if obj.is_single:
tokenized_obj = await self._tokenize_one_request(obj)
state = self._send_one_request(obj, tokenized_obj, created_time)
async for response in self._wait_one_response(obj, state, request):
@@ -514,12 +495,12 @@ class TokenizerManager:
else:
image_inputs: Optional[Dict] = None
self._validate_token_len(obj, input_ids)
self._validate_one_request(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
)
def _validate_token_len(
def _validate_one_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
) -> None:
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
@@ -548,6 +529,24 @@ class TokenizerManager:
)
raise ValueError(error_msg)
if isinstance(obj, GenerateReqInput):
if (
obj.return_hidden_states
and not self.server_args.enable_return_hidden_states
):
raise ValueError(
"The server is not configured to return the hidden states. "
"Please set `--enable-return-hidden-states` to enable this feature."
)
if (
obj.custom_logit_processor
and not self.server_args.enable_custom_logit_processor
):
raise ValueError(
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
def _create_tokenized_object(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -558,24 +557,6 @@ class TokenizerManager:
token_type_ids: Optional[List[int]] = None,
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
"""Create a tokenized request object from common parameters."""
if self.is_generation:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
token_ids_logprob = obj.token_ids_logprob
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
if (
obj.custom_logit_processor
and not self.server_args.enable_custom_logit_processor
):
raise ValueError(
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
# Parse sampling parameters
# Note: if there are preferred sampling params, we use them if they are not
# explicitly passed in sampling_params
@@ -589,16 +570,20 @@ class TokenizerManager:
# Build return object
if isinstance(obj, GenerateReqInput):
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
tokenized_obj = TokenizedGenerateReqInput(
obj.rid,
input_text,
input_ids,
image_inputs,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
token_ids_logprob,
obj.return_logprob,
obj.logprob_start_len,
obj.top_logprobs_num,
obj.token_ids_logprob,
obj.stream,
bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port,