[Minor] clean up multimodal processor and tokenizer manager (#7624)
This commit is contained in:
@@ -353,8 +353,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
|
||||
obj = GenerateReqInput(
|
||||
input_embeds=input_embeds,
|
||||
sampling_params={
|
||||
"repetition_penalty": 1.2,
|
||||
"temperature": 0.2,
|
||||
"temperature": 0.0,
|
||||
"max_new_tokens": 512,
|
||||
},
|
||||
)
|
||||
@@ -393,16 +392,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route(
|
||||
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
|
||||
)
|
||||
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
|
||||
"""Endpoint for reranking documents based on query relevance."""
|
||||
return await raw_request.app.state.openai_serving_rerank.handle_request(
|
||||
request, raw_request
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/flush_cache", methods=["GET", "POST"])
|
||||
async def flush_cache():
|
||||
"""Flush the radix cache."""
|
||||
@@ -841,6 +830,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
|
||||
)
|
||||
|
||||
|
||||
@app.api_route(
|
||||
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
|
||||
)
|
||||
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
|
||||
"""Endpoint for reranking documents based on query relevance."""
|
||||
return await raw_request.app.state.openai_serving_rerank.handle_request(
|
||||
request, raw_request
|
||||
)
|
||||
|
||||
|
||||
def _create_error_response(e):
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
|
||||
@@ -22,17 +22,16 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.multimodal.mm_utils import has_valid_data
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
|
||||
# handle serialization of Image for pydantic
|
||||
# Handle serialization of Image for pydantic
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image
|
||||
else:
|
||||
Image = Any
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionParams:
|
||||
@@ -182,6 +181,7 @@ class GenerateReqInput:
|
||||
# Determine parallel sample count
|
||||
if self.sampling_params is None:
|
||||
self.parallel_sample_num = 1
|
||||
return
|
||||
elif isinstance(self.sampling_params, dict):
|
||||
self.parallel_sample_num = self.sampling_params.get("n", 1)
|
||||
else: # isinstance(self.sampling_params, list):
|
||||
|
||||
@@ -25,7 +25,6 @@ def get_dummy_processor():
|
||||
return DummyMultimodalProcessor()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_processors():
|
||||
package_name = "sglang.srt.multimodal.processors"
|
||||
package = importlib.import_module(package_name)
|
||||
|
||||
@@ -180,46 +180,48 @@ class Modality(Enum):
|
||||
@dataclasses.dataclass
|
||||
class MultimodalDataItem:
|
||||
"""
|
||||
A single multimodal data, from a single image/video/audio or others
|
||||
A single multimodal data, from a single image/video/audio or others.
|
||||
|
||||
We put the common fields first and the model-specific fields last.
|
||||
"""
|
||||
|
||||
modality: Modality
|
||||
|
||||
hash: int = None
|
||||
pad_value: int = None
|
||||
|
||||
aspect_ratio_id: Optional[List[torch.Tensor]] = None
|
||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||
|
||||
image_sizes: Tuple[int, int] = None
|
||||
image_offsets: Optional[list] = None
|
||||
|
||||
# the real data, pixel_values or audio_features
|
||||
# data: Union[List[torch.Tensor], List[np.ndarray]]
|
||||
pixel_values: Union[torch.Tensor, np.ndarray] = None
|
||||
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
|
||||
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
|
||||
|
||||
image_emb_mask: Optional[torch.Tensor] = None
|
||||
image_spatial_crop: Optional[torch.Tensor] = None
|
||||
second_per_grid_ts: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# [num_images, (n, w, h)]
|
||||
tgt_size: Tuple[int, int] = None
|
||||
|
||||
# kimi-vl related
|
||||
image_grid_hws: Optional[List[torch.Tensor]] = None
|
||||
|
||||
audio_features: Union[torch.Tensor, np.ndarray] = None
|
||||
audio_feature_lens: Optional[List[torch.Tensor]] = None
|
||||
audio_offsets: Optional[List[Tuple[int, int]]] = None
|
||||
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
||||
|
||||
# gemma3n related
|
||||
# For qwen-vl
|
||||
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
|
||||
second_per_grid_ts: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# For deepseek-vl
|
||||
image_emb_mask: Optional[torch.Tensor] = None
|
||||
image_spatial_crop: Optional[torch.Tensor] = None
|
||||
|
||||
# For minicpmv
|
||||
# [num_images, (n, w, h)]
|
||||
tgt_size: Tuple[int, int] = None
|
||||
|
||||
# For mllama
|
||||
aspect_ratio_id: Optional[List[torch.Tensor]] = None
|
||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# For kimi-vl
|
||||
image_grid_hws: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# For gemma3n
|
||||
input_features: Optional[torch.Tensor] = None
|
||||
input_features_mask: Optional[torch.Tensor] = None
|
||||
|
||||
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
|
||||
|
||||
@staticmethod
|
||||
def is_empty_list(l):
|
||||
if l is None:
|
||||
@@ -339,10 +341,6 @@ class MultimodalInputs:
|
||||
image_pad_len: Optional[list] = None
|
||||
num_image_tokens: Optional[int] = None
|
||||
|
||||
# QWen2-VL related
|
||||
mrope_positions: Optional[torch.Tensor] = None
|
||||
mrope_position_delta: Optional[torch.Tensor] = None
|
||||
|
||||
# image
|
||||
im_token_id: Optional[int] = None
|
||||
im_start_id: Optional[int] = None
|
||||
@@ -358,6 +356,10 @@ class MultimodalInputs:
|
||||
audio_start_id: Optional[int] = None
|
||||
audio_end_id: Optional[int] = None
|
||||
|
||||
# QWen2-VL related
|
||||
mrope_positions: Optional[torch.Tensor] = None
|
||||
mrope_position_delta: Optional[torch.Tensor] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: dict):
|
||||
ret = MultimodalInputs(
|
||||
|
||||
@@ -150,7 +150,9 @@ class ReqState:
|
||||
|
||||
# For streaming output
|
||||
last_output_offset: int = 0
|
||||
|
||||
# For incremental state update.
|
||||
# TODO(lianmin): do not initialize some lists if not needed.
|
||||
text: str = ""
|
||||
output_ids: List[int] = dataclasses.field(default_factory=list)
|
||||
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
||||
@@ -199,7 +201,6 @@ class TokenizerManager:
|
||||
self.model_path = server_args.model_path
|
||||
self.served_model_name = server_args.served_model_name
|
||||
self.model_config = ModelConfig.from_server_args(server_args)
|
||||
|
||||
self.is_generation = self.model_config.is_generation
|
||||
self.is_image_gen = self.model_config.is_image_gen
|
||||
self.context_len = self.model_config.context_len
|
||||
@@ -251,19 +252,36 @@ class TokenizerManager:
|
||||
self.dump_requests_threshold = 1000
|
||||
self.dump_request_list: List[Tuple] = []
|
||||
self.log_request_metadata = self.get_log_request_metadata()
|
||||
self.asyncio_tasks = set()
|
||||
self.session_futures = {} # session_id -> asyncio event
|
||||
self.max_req_input_len = None
|
||||
|
||||
# The event to notify the weight sync is finished.
|
||||
self.model_update_lock = RWLock()
|
||||
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
|
||||
None
|
||||
)
|
||||
self.asyncio_tasks = set()
|
||||
|
||||
# For session info
|
||||
self.session_futures = {} # session_id -> asyncio event
|
||||
# For pd disaggregtion
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
)
|
||||
self.transfer_backend = TransferBackend(
|
||||
self.server_args.disaggregation_transfer_backend
|
||||
)
|
||||
# Start kv boostrap server on prefill
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# only start bootstrap server on prefill tm
|
||||
kv_bootstrap_server_class = get_kv_class(
|
||||
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||
)
|
||||
self.bootstrap_server = kv_bootstrap_server_class(
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
|
||||
# Set after scheduler is initialized
|
||||
self.max_req_input_len = None
|
||||
# For load balancing
|
||||
self.current_load = 0
|
||||
self.current_load_lock = asyncio.Lock()
|
||||
|
||||
# Metrics
|
||||
if self.enable_metrics:
|
||||
@@ -393,34 +411,14 @@ class TokenizerManager:
|
||||
]
|
||||
)
|
||||
|
||||
# For pd disaggregtion
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
)
|
||||
self.transfer_backend = TransferBackend(
|
||||
self.server_args.disaggregation_transfer_backend
|
||||
)
|
||||
# Start kv boostrap server on prefill
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# only start bootstrap server on prefill tm
|
||||
kv_bootstrap_server_class = get_kv_class(
|
||||
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||
)
|
||||
self.bootstrap_server = kv_bootstrap_server_class(
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
|
||||
self.current_load = 0
|
||||
self.current_load_lock = asyncio.Lock()
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
created_time = time.time()
|
||||
|
||||
self.auto_create_handle_loop()
|
||||
obj.normalize_batch_and_arguments()
|
||||
|
||||
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
||||
raise ValueError(
|
||||
@@ -428,22 +426,6 @@ class TokenizerManager:
|
||||
"Please add `--is-embedding` when launching the server or try another model."
|
||||
)
|
||||
|
||||
obj.normalize_batch_and_arguments()
|
||||
|
||||
if isinstance(obj, GenerateReqInput):
|
||||
return_hidden_states = obj.return_hidden_states
|
||||
has_return_hidden_states = return_hidden_states == True or (
|
||||
isinstance(return_hidden_states, list) and any(return_hidden_states)
|
||||
)
|
||||
if (
|
||||
not self.server_args.enable_return_hidden_states
|
||||
and has_return_hidden_states
|
||||
):
|
||||
raise ValueError(
|
||||
"return_hidden_states=True requires the server to be started "
|
||||
"with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
|
||||
)
|
||||
|
||||
if self.log_requests:
|
||||
max_length, skip_names, _ = self.log_request_metadata
|
||||
logger.info(
|
||||
@@ -451,8 +433,7 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
async with self.model_update_lock.reader_lock:
|
||||
is_single = obj.is_single
|
||||
if is_single:
|
||||
if obj.is_single:
|
||||
tokenized_obj = await self._tokenize_one_request(obj)
|
||||
state = self._send_one_request(obj, tokenized_obj, created_time)
|
||||
async for response in self._wait_one_response(obj, state, request):
|
||||
@@ -514,12 +495,12 @@ class TokenizerManager:
|
||||
else:
|
||||
image_inputs: Optional[Dict] = None
|
||||
|
||||
self._validate_token_len(obj, input_ids)
|
||||
self._validate_one_request(obj, input_ids)
|
||||
return self._create_tokenized_object(
|
||||
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
|
||||
)
|
||||
|
||||
def _validate_token_len(
|
||||
def _validate_one_request(
|
||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
|
||||
) -> None:
|
||||
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
|
||||
@@ -548,6 +529,24 @@ class TokenizerManager:
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if isinstance(obj, GenerateReqInput):
|
||||
if (
|
||||
obj.return_hidden_states
|
||||
and not self.server_args.enable_return_hidden_states
|
||||
):
|
||||
raise ValueError(
|
||||
"The server is not configured to return the hidden states. "
|
||||
"Please set `--enable-return-hidden-states` to enable this feature."
|
||||
)
|
||||
if (
|
||||
obj.custom_logit_processor
|
||||
and not self.server_args.enable_custom_logit_processor
|
||||
):
|
||||
raise ValueError(
|
||||
"The server is not configured to enable custom logit processor. "
|
||||
"Please set `--enable-custom-logits-processor` to enable this feature."
|
||||
)
|
||||
|
||||
def _create_tokenized_object(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
@@ -558,24 +557,6 @@ class TokenizerManager:
|
||||
token_type_ids: Optional[List[int]] = None,
|
||||
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
|
||||
"""Create a tokenized request object from common parameters."""
|
||||
|
||||
if self.is_generation:
|
||||
return_logprob = obj.return_logprob
|
||||
logprob_start_len = obj.logprob_start_len
|
||||
top_logprobs_num = obj.top_logprobs_num
|
||||
token_ids_logprob = obj.token_ids_logprob
|
||||
session_params = (
|
||||
SessionParams(**obj.session_params) if obj.session_params else None
|
||||
)
|
||||
if (
|
||||
obj.custom_logit_processor
|
||||
and not self.server_args.enable_custom_logit_processor
|
||||
):
|
||||
raise ValueError(
|
||||
"The server is not configured to enable custom logit processor. "
|
||||
"Please set `--enable-custom-logits-processor` to enable this feature."
|
||||
)
|
||||
|
||||
# Parse sampling parameters
|
||||
# Note: if there are preferred sampling params, we use them if they are not
|
||||
# explicitly passed in sampling_params
|
||||
@@ -589,16 +570,20 @@ class TokenizerManager:
|
||||
|
||||
# Build return object
|
||||
if isinstance(obj, GenerateReqInput):
|
||||
session_params = (
|
||||
SessionParams(**obj.session_params) if obj.session_params else None
|
||||
)
|
||||
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
obj.rid,
|
||||
input_text,
|
||||
input_ids,
|
||||
image_inputs,
|
||||
sampling_params,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
token_ids_logprob,
|
||||
obj.return_logprob,
|
||||
obj.logprob_start_len,
|
||||
obj.top_logprobs_num,
|
||||
obj.token_ids_logprob,
|
||||
obj.stream,
|
||||
bootstrap_host=obj.bootstrap_host,
|
||||
bootstrap_port=obj.bootstrap_port,
|
||||
|
||||
@@ -98,6 +98,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
self._processor = _processor
|
||||
self.arch = hf_config.architectures[0]
|
||||
self.server_args = server_args
|
||||
|
||||
# FIXME: not accurate, model and image specific
|
||||
self.NUM_TOKEN_PER_FRAME = 330
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import torch
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.sampling_params import TOP_K_ALL
|
||||
from sglang.srt.utils import merge_bias_tensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
|
||||
self.logit_bias = merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
|
||||
)
|
||||
|
||||
|
||||
def merge_bias_tensor(
|
||||
lhs: Optional[torch.Tensor],
|
||||
rhs: Optional[torch.Tensor],
|
||||
bs1: int,
|
||||
bs2: int,
|
||||
device: str,
|
||||
default: float,
|
||||
):
|
||||
"""Merge two bias tensors for batch merging.
|
||||
|
||||
Args:
|
||||
lhs: Left-hand side tensor
|
||||
rhs: Right-hand side tensor
|
||||
bs1: Batch size of left-hand side tensor
|
||||
bs2: Batch size of right-hand side tensor
|
||||
device: Device to place the merged tensor on
|
||||
default: Default value for missing tensor elements
|
||||
|
||||
Returns:
|
||||
Merged tensor or None if both inputs are None
|
||||
"""
|
||||
if lhs is None and rhs is None:
|
||||
return None
|
||||
|
||||
if lhs is not None and rhs is not None:
|
||||
return torch.cat([lhs, rhs])
|
||||
else:
|
||||
if lhs is not None:
|
||||
shape, dtype = lhs.shape[1:], lhs.dtype
|
||||
else:
|
||||
shape, dtype = rhs.shape[1:], rhs.dtype
|
||||
|
||||
if lhs is None:
|
||||
lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
|
||||
if rhs is None:
|
||||
rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
|
||||
return torch.cat([lhs, rhs])
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import is_cuda, is_hip, rank0_print
|
||||
from sglang.srt.utils import is_cuda, is_hip, rank0_log
|
||||
|
||||
if is_cuda() or is_hip():
|
||||
from sgl_kernel import (
|
||||
@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
|
||||
num_verify_tokens=num_draft_token,
|
||||
)
|
||||
|
||||
rank0_print("=========== build tree kernel efficient ==========")
|
||||
# rank0_print(f"{tree_mask=}", flush=True)
|
||||
rank0_print(f"{position=}", flush=True)
|
||||
rank0_print(f"{retrive_index=}", flush=True)
|
||||
rank0_print(f"{retrive_next_token=}", flush=True)
|
||||
rank0_print(f"{retrive_next_sibling=}", flush=True)
|
||||
rank0_print(f"{draft_tokens=}", flush=True)
|
||||
rank0_log("=========== build tree kernel efficient ==========")
|
||||
# rank0_log(f"{tree_mask=}")
|
||||
rank0_log(f"{position=}")
|
||||
rank0_log(f"{retrive_index=}")
|
||||
rank0_log(f"{retrive_next_token=}")
|
||||
rank0_log(f"{retrive_next_sibling=}")
|
||||
rank0_log(f"{draft_tokens=}")
|
||||
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
|
||||
assert retrive_index.tolist() == [
|
||||
[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
|
||||
@@ -1917,14 +1917,11 @@ def configure_ipv6(dist_init_addr):
|
||||
return port, host
|
||||
|
||||
|
||||
def rank0_print(msg: str):
|
||||
def rank0_log(msg: str):
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
print(msg, flush=True)
|
||||
|
||||
|
||||
rank0_log = rank0_print
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
def get_cuda_version():
|
||||
@@ -2344,45 +2341,6 @@ def require_mlp_sync(server_args):
|
||||
return server_args.enable_dp_attention or require_gathered_buffer(server_args)
|
||||
|
||||
|
||||
def merge_bias_tensor(
|
||||
lhs: Optional[torch.Tensor],
|
||||
rhs: Optional[torch.Tensor],
|
||||
bs1: int,
|
||||
bs2: int,
|
||||
device: str,
|
||||
default: float,
|
||||
):
|
||||
"""Merge two bias tensors for batch merging.
|
||||
|
||||
Args:
|
||||
lhs: Left-hand side tensor
|
||||
rhs: Right-hand side tensor
|
||||
bs1: Batch size of left-hand side tensor
|
||||
bs2: Batch size of right-hand side tensor
|
||||
device: Device to place the merged tensor on
|
||||
default: Default value for missing tensor elements
|
||||
|
||||
Returns:
|
||||
Merged tensor or None if both inputs are None
|
||||
"""
|
||||
if lhs is None and rhs is None:
|
||||
return None
|
||||
|
||||
if lhs is not None and rhs is not None:
|
||||
return torch.cat([lhs, rhs])
|
||||
else:
|
||||
if lhs is not None:
|
||||
shape, dtype = lhs.shape[1:], lhs.dtype
|
||||
else:
|
||||
shape, dtype = rhs.shape[1:], rhs.dtype
|
||||
|
||||
if lhs is None:
|
||||
lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
|
||||
if rhs is None:
|
||||
rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
|
||||
return torch.cat([lhs, rhs])
|
||||
|
||||
|
||||
def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
|
||||
import huggingface_hub as hf
|
||||
|
||||
|
||||
Reference in New Issue
Block a user