diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 812dc63cb..61aec045e 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -353,8 +353,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     obj = GenerateReqInput(
         input_embeds=input_embeds,
         sampling_params={
-            "repetition_penalty": 1.2,
-            "temperature": 0.2,
+            "temperature": 0.0,
             "max_new_tokens": 512,
         },
     )
@@ -393,16 +392,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
-
-
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
@@ -841,6 +830,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
     )


+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index aebd820ab..800dfc1fd 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -22,17 +22,16 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
+from sglang.srt.sampling.sampling_params import SamplingParams

-# handle serialization of Image for pydantic
+# Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any

-from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.sampling_params import SamplingParams
-

 @dataclass
 class SessionParams:
@@ -182,6 +181,7 @@ class GenerateReqInput:
         # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
+            return
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
diff --git a/python/sglang/srt/managers/multimodal_processor.py b/python/sglang/srt/managers/multimodal_processor.py
index 3980947d7..faf6576e6 100644
--- a/python/sglang/srt/managers/multimodal_processor.py
+++ b/python/sglang/srt/managers/multimodal_processor.py
@@ -25,7 +25,6 @@ def get_dummy_processor():
    return DummyMultimodalProcessor()


-@lru_cache()
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
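Note on the `http_server.py` hunks above: the `/v1/rerank` handler is unchanged; it only moves next to `/v1/score` so the OpenAI-compatible routes sit together. A minimal client sketch for exercising the route follows; the `query`/`documents` payload shape is an assumption based on the `V1RerankReqInput` name, so verify it against the dataclass before relying on it.

```python
# Hypothetical client call for the relocated /v1/rerank route.
# The default local port (30000) and the query/documents payload are
# assumptions, not part of this diff.
import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",
    json={
        "query": "What is the capital of France?",
        "documents": [
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
    },
)
resp.raise_for_status()
print(resp.json())  # expected: per-document relevance scores
```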
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 6728f8852..8053c35da 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -180,46 +180,48 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or others
+    A single multimodal data, from a single image/video/audio or others.
+
+    We put the common fields first and the model-specific fields last.
     """

     modality: Modality
-
     hash: int = None
     pad_value: int = None
-
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
     image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None

     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
     pixel_values: Union[torch.Tensor, np.ndarray] = None
-    image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-    video_grid_thws: Union[torch.Tensor, np.ndarray] = None
-
-    image_emb_mask: Optional[torch.Tensor] = None
-    image_spatial_crop: Optional[torch.Tensor] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
-
-    # [num_images, (n, w, h)]
-    tgt_size: Tuple[int, int] = None
-
-    # kimi-vl related
-    image_grid_hws: Optional[List[torch.Tensor]] = None
-
     audio_features: Union[torch.Tensor, np.ndarray] = None
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

-    # gemma3n related
+    # For qwen-vl
+    image_grid_thw: Union[torch.Tensor, np.ndarray] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
+
+    # For deepseek-vl
+    image_emb_mask: Optional[torch.Tensor] = None
+    image_spatial_crop: Optional[torch.Tensor] = None
+
+    # For minicpmv
+    # [num_images, (n, w, h)]
+    tgt_size: Tuple[int, int] = None
+
+    # For mllama
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    # For kimi-vl
+    image_grid_hws: Optional[List[torch.Tensor]] = None
+
+    # For gemma3n
     input_features: Optional[torch.Tensor] = None
     input_features_mask: Optional[torch.Tensor] = None

-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
-
     @staticmethod
     def is_empty_list(l):
         if l is None:
@@ -339,10 +341,6 @@ class MultimodalInputs:
     image_pad_len: Optional[list] = None
     num_image_tokens: Optional[int] = None

-    # QWen2-VL related
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[torch.Tensor] = None
-
     # image
     im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
@@ -358,6 +356,10 @@
     audio_start_id: Optional[int] = None
     audio_end_id: Optional[int] = None

+    # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
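The `MultimodalDataItem` hunk above is essentially behavior-neutral: it regroups the dataclass fields so the shared ones come first and the model-specific ones are labeled by model family (note that `video_grid_thws` is dropped rather than regrouped). A rough construction sketch under the new grouping; the field semantics and the `Modality.IMAGE` member are inferred from the declarations above, and the shapes are invented:

```python
# Illustrative only: a qwen-vl-style image item under the regrouped layout.
import torch
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

item = MultimodalDataItem(
    modality=Modality.IMAGE,                     # common field
    pixel_values=torch.randn(1, 3, 336, 336),    # common field: the raw data
    image_grid_thw=torch.tensor([[1, 24, 24]]),  # qwen-vl specific field
)
```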
text: str = "" output_ids: List[int] = dataclasses.field(default_factory=list) input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) @@ -199,7 +201,6 @@ class TokenizerManager: self.model_path = server_args.model_path self.served_model_name = server_args.served_model_name self.model_config = ModelConfig.from_server_args(server_args) - self.is_generation = self.model_config.is_generation self.is_image_gen = self.model_config.is_image_gen self.context_len = self.model_config.context_len @@ -251,19 +252,36 @@ class TokenizerManager: self.dump_requests_threshold = 1000 self.dump_request_list: List[Tuple] = [] self.log_request_metadata = self.get_log_request_metadata() + self.asyncio_tasks = set() + self.session_futures = {} # session_id -> asyncio event + self.max_req_input_len = None # The event to notify the weight sync is finished. self.model_update_lock = RWLock() self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = ( None ) - self.asyncio_tasks = set() - # For session info - self.session_futures = {} # session_id -> asyncio event + # For pd disaggregtion + self.disaggregation_mode = DisaggregationMode( + self.server_args.disaggregation_mode + ) + self.transfer_backend = TransferBackend( + self.server_args.disaggregation_transfer_backend + ) + # Start kv boostrap server on prefill + if self.disaggregation_mode == DisaggregationMode.PREFILL: + # only start bootstrap server on prefill tm + kv_bootstrap_server_class = get_kv_class( + self.transfer_backend, KVClassType.BOOTSTRAP_SERVER + ) + self.bootstrap_server = kv_bootstrap_server_class( + self.server_args.disaggregation_bootstrap_port + ) - # Set after scheduler is initialized - self.max_req_input_len = None + # For load balancing + self.current_load = 0 + self.current_load_lock = asyncio.Lock() # Metrics if self.enable_metrics: @@ -393,34 +411,14 @@ class TokenizerManager: ] ) - # For pd disaggregtion - self.disaggregation_mode = DisaggregationMode( - self.server_args.disaggregation_mode - ) - self.transfer_backend = TransferBackend( - self.server_args.disaggregation_transfer_backend - ) - # Start kv boostrap server on prefill - if self.disaggregation_mode == DisaggregationMode.PREFILL: - # only start bootstrap server on prefill tm - kv_bootstrap_server_class = get_kv_class( - self.transfer_backend, KVClassType.BOOTSTRAP_SERVER - ) - self.bootstrap_server = kv_bootstrap_server_class( - self.server_args.disaggregation_bootstrap_port - ) - - self.current_load = 0 - self.current_load_lock = asyncio.Lock() - async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, ): created_time = time.time() - self.auto_create_handle_loop() + obj.normalize_batch_and_arguments() if isinstance(obj, EmbeddingReqInput) and self.is_generation: raise ValueError( @@ -428,22 +426,6 @@ class TokenizerManager: "Please add `--is-embedding` when launching the server or try another model." ) - obj.normalize_batch_and_arguments() - - if isinstance(obj, GenerateReqInput): - return_hidden_states = obj.return_hidden_states - has_return_hidden_states = return_hidden_states == True or ( - isinstance(return_hidden_states, list) and any(return_hidden_states) - ) - if ( - not self.server_args.enable_return_hidden_states - and has_return_hidden_states - ): - raise ValueError( - "return_hidden_states=True requires the server to be started " - "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)." 
diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py
index 4b8732627..2e6ee5ed3 100644
--- a/python/sglang/srt/multimodal/processors/base_processor.py
+++ b/python/sglang/srt/multimodal/processors/base_processor.py
@@ -98,6 +98,7 @@ class BaseMultimodalProcessor(ABC):
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
+        # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
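The new FIXME makes explicit that `NUM_TOKEN_PER_FRAME = 330` is a rough, model-agnostic constant. Wherever it feeds a token budget, the estimate is a plain multiplication; illustrative numbers only:

```python
# Back-of-the-envelope token budget for a video input. As the FIXME notes,
# the real per-frame count is model- and image-size-specific.
NUM_TOKEN_PER_FRAME = 330

num_frames = 16
estimated_tokens = num_frames * NUM_TOKEN_PER_FRAME
print(estimated_tokens)  # 5280
```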
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index efacf37ad..f88082e69 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -10,7 +10,6 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
-from sglang.srt.utils import merge_bias_tensor

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
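`merge_bias_tensor` moves here unchanged from `utils.py` (see the matching removal below), placing it next to its caller in `SamplingBatchInfo`. A worked example of the padding behavior when one side has no bias tensor, exercising the function exactly as defined above:

```python
# When one side is None, merge_bias_tensor materializes it filled with the
# default value so the result covers both batches.
import torch

from sglang.srt.sampling.sampling_batch_info import merge_bias_tensor

lhs = None                         # 2 requests without logit bias
rhs = torch.tensor([[1.0, -1.0]])  # 1 request with a bias over 2 tokens

merged = merge_bias_tensor(lhs, rhs, bs1=2, bs2=1, device="cpu", default=0.0)
print(merged)
# tensor([[ 0.,  0.],
#         [ 0.,  0.],
#         [ 1., -1.]])
```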
diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py
index c6b853cc6..c53a13f4a 100644
--- a/python/sglang/srt/speculative/build_eagle_tree.py
+++ b/python/sglang/srt/speculative/build_eagle_tree.py
@@ -4,7 +4,7 @@ from typing import List

 import torch

-from sglang.srt.utils import is_cuda, is_hip, rank0_print
+from sglang.srt.utils import is_cuda, is_hip, rank0_log

 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )

-    rank0_print("=========== build tree kernel efficient ==========")
-    # rank0_print(f"{tree_mask=}", flush=True)
-    rank0_print(f"{position=}", flush=True)
-    rank0_print(f"{retrive_index=}", flush=True)
-    rank0_print(f"{retrive_next_token=}", flush=True)
-    rank0_print(f"{retrive_next_sibling=}", flush=True)
-    rank0_print(f"{draft_tokens=}", flush=True)
+    rank0_log("=========== build tree kernel efficient ==========")
+    # rank0_log(f"{tree_mask=}")
+    rank0_log(f"{position=}")
+    rank0_log(f"{retrive_index=}")
+    rank0_log(f"{retrive_next_token=}")
+    rank0_log(f"{retrive_next_sibling=}")
+    rank0_log(f"{draft_tokens=}")
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
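One detail in the hunk above: dropping `flush=True` from the call sites is not just cosmetic. The old helper (removed in the `utils.py` diff below) was declared as `rank0_print(msg: str)`, so forwarding `flush=True` would raise a `TypeError` at runtime; the new `rank0_log` routes through `logger.info`, which handles flushing itself. A minimal reproduction of the old mismatch:

```python
# The old signature accepted only a message, so the keyword call fails.
def rank0_print(msg: str):
    print(msg, flush=True)

rank0_print("ok")                # prints "ok"
rank0_print("boom", flush=True)  # TypeError: unexpected keyword argument 'flush'
```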
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 8f3d712e6..8a91c2fc4 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1917,14 +1917,11 @@ def configure_ipv6(dist_init_addr):
     return port, host


-def rank0_print(msg: str):
+def rank0_log(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

     if get_tensor_model_parallel_rank() == 0:
-        print(msg, flush=True)
-
-
-rank0_log = rank0_print
+        logger.info(msg)


 def get_cuda_version():
@@ -2344,45 +2341,6 @@ def require_mlp_sync(server_args):
     return server_args.enable_dp_attention or require_gathered_buffer(server_args)


-def merge_bias_tensor(
-    lhs: Optional[torch.Tensor],
-    rhs: Optional[torch.Tensor],
-    bs1: int,
-    bs2: int,
-    device: str,
-    default: float,
-):
-    """Merge two bias tensors for batch merging.
-
-    Args:
-        lhs: Left-hand side tensor
-        rhs: Right-hand side tensor
-        bs1: Batch size of left-hand side tensor
-        bs2: Batch size of right-hand side tensor
-        device: Device to place the merged tensor on
-        default: Default value for missing tensor elements
-
-    Returns:
-        Merged tensor or None if both inputs are None
-    """
-    if lhs is None and rhs is None:
-        return None
-
-    if lhs is not None and rhs is not None:
-        return torch.cat([lhs, rhs])
-    else:
-        if lhs is not None:
-            shape, dtype = lhs.shape[1:], lhs.dtype
-        else:
-            shape, dtype = rhs.shape[1:], rhs.dtype
-
-        if lhs is None:
-            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
-        if rhs is None:
-            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
-        return torch.cat([lhs, rhs])
-
-
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf