From e30ef368abd3be9293838e133d33715a7515a99e Mon Sep 17 00:00:00 2001
From: woodx <124784234+woodx9@users.noreply.github.com>
Date: Tue, 17 Jun 2025 01:50:01 +0800
Subject: [PATCH] Feat/support rerank (#6058)

---
 python/sglang/srt/configs/model_config.py     |   5 +
 python/sglang/srt/entrypoints/engine.py       |  14 ++
 python/sglang/srt/entrypoints/http_server.py  |  11 ++
 python/sglang/srt/layers/activation.py        |  19 +++
 python/sglang/srt/layers/pooler.py            |  56 ++++++++
 python/sglang/srt/managers/io_struct.py       |  22 ++-
 python/sglang/srt/managers/schedule_batch.py  |  21 +++
 python/sglang/srt/managers/scheduler.py       |   1 +
 .../sglang/srt/managers/tokenizer_manager.py  |  17 ++-
 .../srt/model_executor/forward_batch_info.py  |   6 +-
 python/sglang/srt/models/bert.py              | 126 ++++++++++++++++--
 python/sglang/srt/models/roberta.py           | 126 ++++++++++++++++--
 python/sglang/srt/openai_api/adapter.py       |  65 ++++++++-
 python/sglang/srt/openai_api/protocol.py      |   7 +
 python/sglang/test/runners.py                 |  41 +++++-
 python/sglang/test/test_utils.py              |   2 +
 python/sglang/utils.py                        |   9 ++
 test/srt/models/test_cross_encoder_models.py  |  91 +++++++++++++
 test/srt/run_suite.py                         |   2 +
 test/srt/test_openai_server.py                |  73 ++++++++++
 20 files changed, 684 insertions(+), 30 deletions(-)
 create mode 100644 test/srt/models/test_cross_encoder_models.py

diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index b52ae3957..6ddd2484a 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -550,6 +550,11 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
         or "Qwen2ForRewardModel" in model_architectures
         or "Qwen2ForSequenceClassification" in model_architectures
         or "CLIPModel" in model_architectures
+        or "BertModel" in model_architectures
+        or "Contriever" in model_architectures
+        or "BertForSequenceClassification" in model_architectures
+        or "XLMRobertaModel" in model_architectures
+        or "XLMRobertaForSequenceClassification" in model_architectures
     ):
         return False
     else:
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 45e159d63..96d5e0801 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -327,6 +327,20 @@ class Engine(EngineBase):
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
 
+    def rerank(
+        self,
+        prompt: List[List[str]],
+    ) -> Dict:
+        """
+        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
+        Please refer to `EmbeddingReqInput` for the documentation.
+ """ + obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True) + loop = asyncio.get_event_loop() + generator = self.tokenizer_manager.generate_request(obj, None) + ret = loop.run_until_complete(generator.__anext__()) + return ret + def shutdown(self): """Shutdown the engine""" kill_process_tree(os.getpid(), include_parent=False) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 89417fd86..9262d10a9 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -67,6 +67,7 @@ from sglang.srt.managers.io_struct import ( UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, + V1RerankReqInput, VertexGenerateReqInput, ) from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -79,6 +80,7 @@ from sglang.srt.openai_api.adapter import ( v1_delete_file, v1_embeddings, v1_files_create, + v1_rerank, v1_retrieve_batch, v1_retrieve_file, v1_retrieve_file_content, @@ -328,6 +330,15 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): return _create_error_response(e) +@app.api_route("/v1/rerank", methods=["POST", "PUT"]) +async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request): + try: + ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request) + return ret + except ValueError as e: + return _create_error_response(e) + + @app.api_route("/flush_cache", methods=["GET", "POST"]) async def flush_cache(): """Flush the radix cache.""" diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index b428246a7..2e200be36 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -20,6 +20,7 @@ from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F +from transformers import PretrainedConfig from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( @@ -29,6 +30,7 @@ from sglang.srt.distributed import ( ) from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import is_cuda, set_weight_attrs +from sglang.utils import resolve_obj_by_qualname _is_cuda = is_cuda() @@ -165,6 +167,23 @@ def get_act_fn( return act_fn +def get_cross_encoder_activation_function(config: PretrainedConfig): + if ( + hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None + ): + + function_name = config.sbert_ce_default_activation_function + assert function_name.startswith("torch.nn.modules."), ( + "Loading of activation functions is restricted to " + "torch.nn.modules for security reasons" + ) + return resolve_obj_by_qualname(function_name)() + else: + # adapt bge-reranker + return nn.Identity() + + if not _is_cuda: logger.info( "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." 
diff --git a/python/sglang/srt/layers/pooler.py b/python/sglang/srt/layers/pooler.py
index 7ee8dbcc2..26bc5899e 100644
--- a/python/sglang/srt/layers/pooler.py
+++ b/python/sglang/srt/layers/pooler.py
@@ -3,10 +3,13 @@
 from dataclasses import dataclass
 from enum import IntEnum
+from typing import Optional
 
 import torch
 import torch.nn as nn
+from transformers import PretrainedConfig
 
+from sglang.srt.layers.activation import get_cross_encoder_activation_function
 from sglang.srt.model_executor.model_runner import ForwardBatch
@@ -54,3 +57,56 @@ class Pooler(nn.Module):
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
         return EmbeddingPoolerOutput(embeddings=pooled_data)
+
+
+class CrossEncodingPooler(nn.Module):
+    """A layer that pools sentence-pair scores from hidden states.
+
+    This layer does the following:
+    1. Splits the packed hidden states into per-request segments.
+    2. Applies the optional pooler and the classifier to each segment.
+    3. Returns the activated scores as `EmbeddingPoolerOutput`.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        classifier: nn.Module,
+        pooler: Optional[nn.Module] = None,
+    ):
+        super().__init__()
+        self.classifier = classifier
+        self.pooler = pooler
+        self.default_activation_function = get_cross_encoder_activation_function(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> EmbeddingPoolerOutput:
+        """Pools sentence-pair scores from the hidden_states."""
+
+        prompt_lens = forward_batch.extend_seq_lens
+
+        offset = 0
+        pooled_data_lst = []
+        for prompt_len in prompt_lens:
+            pooled_data_i = hidden_states[offset : offset + prompt_len]
+
+            if self.pooler is not None:
+                final_shape_tensor = self.pooler(pooled_data_i, forward_batch)
+            else:
+                final_shape_tensor = self.classifier(pooled_data_i)
+
+            pooled_data_lst.append(final_shape_tensor)
+            offset += prompt_len
+
+        pooled_output = torch.stack(pooled_data_lst)
+
+        if self.pooler is not None:
+            # Apply the classifier once on the full stacked batch.
+            pooled_output = self.classifier(pooled_output)
+
+        scores = self.default_activation_function(pooled_output).squeeze(-1)
+
+        return EmbeddingPoolerOutput(embeddings=scores)
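[Editor's note — illustration, not part of the patch] `CrossEncodingPooler.forward` receives the whole batch's hidden states packed along a single token axis; `extend_seq_lens` recovers the per-request segments. A tiny sketch of that splitting arithmetic with dummy tensors:

    import torch

    hidden_states = torch.randn(7, 4)     # two packed requests of 3 and 4 tokens, hidden=4
    extend_seq_lens = torch.tensor([3, 4])

    offset = 0
    for prompt_len in extend_seq_lens:
        segment = hidden_states[offset : offset + prompt_len]
        cls_vector = segment[0, :]        # the per-request first-token state a BERT pooler reads
        offset += prompt_len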
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 3d3c9a270..c94d81eb3 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -481,7 +481,7 @@ class TokenizedGenerateReqInput:
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Optional[Union[List[str], str]] = None
+    text: Optional[Union[List[List[str]], List[str], str]] = None
 
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
     # Can be formatted as:
     # - Single image for a single request
@@ -505,6 +505,8 @@ class EmbeddingReqInput:
     log_metrics: bool = True
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # For cross-encoder requests
+    is_cross_encoder_request: bool = False
 
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
@@ -564,6 +566,16 @@ class EmbeddingReqInput:
         return self.rid
 
     def __getitem__(self, i):
+        if self.is_cross_encoder_request:
+            return EmbeddingReqInput(
+                text=[self.text[i]] if self.text is not None else None,
+                input_ids=None,
+                image_data=None,
+                sampling_params=self.sampling_params[i],
+                rid=self.rid[i],
+                is_cross_encoder_request=True,
+            )
+
         return EmbeddingReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
@@ -583,6 +595,8 @@ class TokenizedEmbeddingReqInput:
     input_ids: List[int]
     # The image inputs
     image_inputs: dict
+    # The token type ids
+    token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
@@ -847,6 +861,12 @@ class SetInternalStateReq:
     server_args: Dict[str, Any]
 
 
+@dataclass
+class V1RerankReqInput:
+    query: str
+    documents: List[str]
+
+
 @dataclass
 class SetInternalStateReqOutput:
     updated: bool
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 9ca8700f0..369340553 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -445,6 +445,7 @@ class Req:
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
+        token_type_ids: Optional[List[int]] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
@@ -470,6 +471,9 @@ class Req:
         self.session_id = session_id
         self.input_embeds = input_embeds
 
+        # For cross-encoder models
+        self.token_type_ids = token_type_ids
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -841,6 +845,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
@@ -1142,6 +1147,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
+        token_type_ids = [
+            r.token_type_ids for r in reqs if r.token_type_ids is not None
+        ]
+
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
@@ -1154,6 +1163,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+
+        token_type_ids_tensor = None
+        if len(token_type_ids) > 0:
+            token_type_ids_tensor = torch.tensor(
+                sum(token_type_ids, []), dtype=torch.int64
+            ).to(self.device, non_blocking=True)
+
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
         # Copy prefix and do some basic check
@@ -1269,6 +1285,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.device, non_blocking=True
         )
         self.multimodal_inputs = multimodal_inputs
+        self.token_type_ids = token_type_ids_tensor
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
@@ -1714,6 +1731,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
             capture_hidden_mode=(
@@ -1807,6 +1825,9 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder models
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8d325791a..6b5a03b82 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1150,6 +1150,7 @@ class Scheduler(
             recv_req.input_text,
             recv_req.input_ids,
             recv_req.sampling_params,
+            token_type_ids=recv_req.token_type_ids,
         )
         req.tokenizer = self.tokenizer
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index e6c3189cb..b6e584d56 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -459,6 +459,10 @@ class TokenizerManager:
         # Tokenize
         input_embeds = None
         input_text = obj.text
+        token_type_ids = None
+        is_cross_encoder_request = (
+            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
+        )
         if obj.input_embeds is not None:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
@@ -477,7 +481,14 @@ class TokenizerManager:
                     "accept text prompts. Please provide input_ids or re-initialize "
                     "the engine with skip_tokenizer_init=False."
                )
-        input_ids = self.tokenizer.encode(input_text)
+        encoded = self.tokenizer(
+            input_text, return_token_type_ids=is_cross_encoder_request
+        )
+
+        input_ids = encoded["input_ids"]
+        if is_cross_encoder_request:
+            input_ids = encoded["input_ids"][0]
+            token_type_ids = encoded.get("token_type_ids", [None])[0]
 
         if self.mm_processor and obj.contains_mm_input():
             image_inputs = await self.mm_processor.process_mm_data_async(
@@ -493,7 +504,7 @@ class TokenizerManager:
         self._validate_token_len(obj, input_ids)
 
         return self._create_tokenized_object(
-            obj, input_text, input_ids, input_embeds, image_inputs
+            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
         )
 
     def _validate_token_len(
@@ -532,6 +543,7 @@ class TokenizerManager:
         input_ids: List[int],
         input_embeds: Optional[Union[List[float], None]] = None,
         image_inputs: Optional[Dict] = None,
+        token_type_ids: Optional[List[int]] = None,
     ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
         """Create a tokenized request object from common parameters."""
 
@@ -592,6 +604,7 @@ class TokenizerManager:
                 input_text,
                 input_ids,
                 image_inputs,
+                token_type_ids,
                 sampling_params,
             )
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 1205ebee6..97e48c10d 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -224,6 +224,9 @@ class ForwardBatch:
     # For input embeddings
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder models
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
 
@@ -300,6 +303,7 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
         )
         device = model_runner.device
@@ -356,8 +360,8 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
+            ret.extend_num_tokens = batch.extend_num_tokens
             if support_triton(model_runner.server_args.attention_backend):
-                ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
                     ret.extend_seq_lens,
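[Editor's note — illustration, not part of the patch] The tokenizer call above leans on the Hugging Face behavior for sentence pairs: a cross-encoder request arrives as a list of [query, document] pairs, and the tokenizer emits token_type_ids that mark the two segments, which the BERT embedding layer below consumes via forward_batch.token_type_ids. A sketch, assuming the bert-base-uncased checkpoint is available:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    encoded = tok(
        [["How many people live in Berlin?", "Berlin is well known for its museums."]],
        return_token_type_ids=True,
    )
    # input_ids[0]:      [CLS] query tokens [SEP] document tokens [SEP]
    # token_type_ids[0]: 0 over the query span, 1 over the document span
    print(encoded["input_ids"][0])
    print(encoded["token_type_ids"][0])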
diff --git a/python/sglang/srt/models/bert.py b/python/sglang/srt/models/bert.py
index 46d2e7265..d7f3301c6 100644
--- a/python/sglang/srt/models/bert.py
+++ b/python/sglang/srt/models/bert.py
@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 BertConfig = None
 
@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
 
         inputs_embeds = self.word_embeddings(input_ids)
 
         # Position embeddings.
-        position_embeddings = self.position_embeddings(position_ids)
+        position_embeddings = self.position_embeddings(positions)
 
-        token_type_ids = torch.zeros(
-            input_shape, dtype=torch.long, device=inputs_embeds.device
-        )
+        token_type_ids = forward_batch.token_type_ids
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=inputs_embeds.device
+            )
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
         return embeddings
 
 
+class BertPooler(nn.Module):
+
+    def __init__(self, config: BertConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # Simply take the hidden state corresponding to the first ([CLS]) token.
+        first_token_tensor = hidden_states[0, :]
+
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+
+        return pooled_output
+
+
 class BertEncoder(nn.Module):
 
     def __init__(
@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
     ):
         super().__init__()
 
+        self.layer_id = layer_id
+
         self.attention = BertAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
         attn_output = self.attention(hidden_states, forward_batch)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
+
         return output
 
 
@@ -326,16 +353,23 @@ class BertModel(nn.Module):
         *,
         config: BertConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_bert_pooler: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_bert_pooler = use_bert_pooler
         self.config = config
         self.embeddings = BertEmbedding(config)
         self.encoder = BertEncoder(
-            config=config, quant_config=quant_config, prefix=f"encoder"
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+        self.pooler = (
+            BertPooler(config)
+            if self.use_bert_pooler
+            else Pooler(pooling_type=PoolingType.LAST, normalize=True)
         )
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        # self.pooler = BertPooler(config)
 
     @torch.no_grad()
     def forward(
@@ -351,11 +385,16 @@ class BertModel(nn.Module):
 
         hidden_states = self.embeddings(
             input_ids=input_ids,
-            position_ids=positions,
+            positions=positions,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        return self.pooler(hidden_states, forward_batch)
+
+        if not self.use_bert_pooler:
+            hidden_states = self.pooler(hidden_states, forward_batch)
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -368,7 +407,7 @@ class BertModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if not self.use_bert_pooler and "pooler" in name:
                 continue
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -395,4 +434,65 @@ class Contriever(BertModel):
     pass
 
 
-EntryClass = [BertModel, Contriever]
+class BertForSequenceClassification(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.num_labels = config.num_labels
+        self.bert = BertModel(
+            config=config,
+            quant_config=quant_config,
+            use_bert_pooler=True,
+            prefix=add_prefix("bert", prefix),
+        )
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("bert."):
+                    yield (name[len("bert.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.bert.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding, "BertForSequenceClassification is only used for rerank"
+
+        hidden_states = self.bert(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=input_embeds,
+            get_embedding=get_embedding,
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+
+EntryClass = [BertModel, Contriever, BertForSequenceClassification]
diff --git a/python/sglang/srt/models/roberta.py b/python/sglang/srt/models/roberta.py
index b982bc8e3..209be1296 100644
--- a/python/sglang/srt/models/roberta.py
+++ b/python/sglang/srt/models/roberta.py
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 
-from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -16,6 +16,23 @@ from sglang.srt.models.bert import BertEncoder
 RobertaConfig = None
 
 
+# Adapted from transformers
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config: RobertaConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[0, :]  # take the <s> token (equivalent to [CLS])
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.out_proj(x)
+        return x
+
+
 class RobertaEmbedding(nn.Module):
 
     def __init__(self, config: RobertaConfig):
@@ -51,8 +68,7 @@ class RobertaEmbedding(nn.Module):
         input_ids: torch.Tensor,
         seq_lens: torch.Tensor,
         position_ids: torch.Tensor,
-        inputs_embeds=None,
-        token_type_ids: Optional[torch.Tensor] = None,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
@@ -82,6 +98,8 @@ class RobertaEmbedding(nn.Module):
 
         # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)
+
+        token_type_ids = forward_batch.token_type_ids
         if token_type_ids is None:
             token_type_ids = torch.zeros(
                 input_shape, dtype=torch.long, device=inputs_embeds.device
             )
@@ -93,20 +111,25 @@
 
         return embeddings
 
 
-class XLMRobertaModel(nn.Module):
+class XLMRobertaBaseModel(nn.Module):
     def __init__(
         self,
         *,
         config: RobertaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        add_pooling_layer: bool = False,
     ):
         super().__init__()
 
         self.config = config
         self.embeddings = RobertaEmbedding(config)
         self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
-        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+        self.pooler = (
+            Pooler(pooling_type=PoolingType.CLS, normalize=True)
+            if add_pooling_layer
+            else None
+        )
 
     @torch.no_grad()
     def forward(
@@ -124,11 +147,12 @@
             input_ids=input_ids,
             position_ids=positions,
             seq_lens=forward_batch.seq_lens,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        pooler_out = self.pooler(hidden_states, forward_batch)
-        return pooler_out
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -141,7 +165,7 @@
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if self.pooler is None and "pooler" in name:
                 continue
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids(
     return incremental_indices.long() + padding_idx
 
 
-EntryClass = [XLMRobertaModel]
+class XLMRobertaModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.roberta.load_weights(weights)
+
+
+class XLMRobertaForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: RobertaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.roberta = XLMRobertaBaseModel(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.classifier = RobertaClassificationHead(config)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.roberta.pooler)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> torch.Tensor:
+        assert (
+            get_embedding
+        ), "XLMRobertaForSequenceClassification is only used for rerank"
+
+        hidden_states = self.roberta(
+            input_ids, positions, forward_batch, input_embeds, get_embedding
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("roberta."):
+                    yield (name[len("roberta.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.roberta.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [XLMRobertaModel, XLMRobertaForSequenceClassification]
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index b37e8d13a..aba1a5afd 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -41,7 +41,11 @@ from sglang.srt.conversation import (
     register_conv_template,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    V1RerankReqInput,
+)
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -69,6 +73,7 @@ from sglang.srt.openai_api.protocol import (
     FunctionResponse,
     LogProbs,
     MultimodalEmbeddingInput,
+    RerankResponse,
     ScoringRequest,
     ScoringResponse,
     ToolCall,
@@ -2020,6 +2025,64 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     return response
 
 
+def v1_rerank_request(obj: V1RerankReqInput):
+    if obj.query is None:
+        raise ValueError("query is required")
+    if obj.documents is None or len(obj.documents) == 0:
+        raise ValueError("documents are required")
+
+    pairs = []
+    for doc in obj.documents:
+        pairs.append([obj.query, doc])
+
+    adapted_request = EmbeddingReqInput(
+        text=pairs,
+        is_cross_encoder_request=True,
+    )
+
+    return adapted_request
+
+
+def v1_rerank_response(ret, obj: V1RerankReqInput):
+    response = []
+    for idx, ret_item in enumerate(ret):
+        response.append(
+            RerankResponse(
+                score=ret_item["embedding"],
+                document=obj.documents[idx],
+                index=idx,
+                meta_info=ret_item["meta_info"],
+            )
+        )
+
+    # Return the documents ordered from most to least relevant.
+    response.sort(key=lambda x: x.score, reverse=True)
+
+    return response
+
+
+async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
+    adapted_request = v1_rerank_request(obj)
+
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    if not isinstance(ret, list):
+        ret = [ret]
+
+    response = v1_rerank_response(
+        ret,
+        obj,
+    )
+
+    return response
+
+
 def to_openai_style_logprobs(
     input_token_logprobs=None,
     output_token_logprobs=None,
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 2d2b76155..71153b912 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -539,6 +539,13 @@ class ScoringResponse(BaseModel):
     object: str = "scoring"
 
 
+class RerankResponse(BaseModel):
+    score: float
+    document: str
+    index: int
+    meta_info: Optional[dict] = None
+
+
 def exclude_if_none(obj, field_names: List[str]):
     omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
     return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
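[Editor's note — usage sketch, not part of the patch] The offline counterpart of the HTTP route: `Engine.rerank` from engine.py above consumes the same [query, document] pairs the adapter builds. A hedged sketch; the launch kwargs mirror the server flags the tests below use:

    import sglang as sgl

    engine = sgl.Engine(
        model_path="BAAI/bge-reranker-v2-m3",  # any supported cross-encoder checkpoint
        is_embedding=True,
        disable_radix_cache=True,
        chunked_prefill_size=-1,
        attention_backend="torch_native",
    )
    query = "How many people live in Berlin?"
    docs = [
        "Berlin is well known for its museums.",
        "Berlin had a population of 3,520,031 registered inhabitants.",
    ]
    out = engine.rerank([[query, d] for d in docs])
    scores = [item["embedding"] for item in out]  # raw cross-encoder scores, one per pair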
diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
index f5c8365a7..b51597d96 100644
--- a/python/sglang/test/runners.py
+++ b/python/sglang/test/runners.py
@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
 ]
+TEST_RERANK_QUERY_DOCS = [
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin is well known for its museums.",
+        ],
+    },
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
+            "Berlin is well known for its museums.",
+        ],
+    },
+]
 
 dirpath = os.path.dirname(__file__)
 with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
@@ -241,7 +256,7 @@ class HFRunner:
                 self.model = _get_sentence_transformer_embedding_model(
                     model_path, torch_dtype
                 )
-            elif self.model_type == "reward":
+            elif self.model_type in ("reward", "cross_encoder"):
                 from transformers import AutoModelForSequenceClassification
 
                 self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -303,6 +318,15 @@ class HFRunner:
                     else:
                         logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
+                elif self.model_type == "cross_encoder":
+                    inputs = self.tokenizer(
+                        prompts, padding=True, return_tensors="pt"
+                    ).to("cuda")
+                    scores = self.model(**inputs).logits
+                    scores = scores.squeeze().tolist()
+                    if not isinstance(scores, list):
+                        scores = [scores]
+                    out_queue.put(ModelOutput(scores=scores))
 
                 elif self.model_type == "reward":
                     scores = []
@@ -322,7 +346,9 @@ class HFRunner:
 
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -526,7 +552,9 @@ class SRTRunner:
 
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -552,6 +580,13 @@ class SRTRunner:
             else:
                 logits = [response["embedding"]]
             return ModelOutput(embed_logits=logits)
+        # cross-encoder model
+        elif self.model_type == "cross_encoder":
+            response = self.engine.rerank(prompts)
+            if not isinstance(response, list):
+                response = [response]
+            scores = [x["embedding"] for x in response]
+            return ModelOutput(scores=scores)
         # reward model
         else:
             response = self.engine.encode(prompts)
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index c83a3ed1e..9e3011dd8 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -41,6 +41,8 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
 
 # MLA test models
+DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
+DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST = "cross-encoder/ms-marco-MiniLM-L6-v2"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 1d994c3b5..6b3f36e19 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -512,3 +512,12 @@ async def async_stream_and_merge(llm, prompt, sampling_params):
         cleaned_chunk = trim_overlap(final_text, chunk_text)
         final_text += cleaned_chunk
         yield cleaned_chunk  # yield the non-overlapping portion
+
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully qualified name.
+    """
+    module_name, obj_name = qualname.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, obj_name)
diff --git a/test/srt/models/test_cross_encoder_models.py b/test/srt/models/test_cross_encoder_models.py
new file mode 100644
index 000000000..93edc3fa1
--- /dev/null
+++ b/test/srt/models/test_cross_encoder_models.py
@@ -0,0 +1,91 @@
+import multiprocessing as mp
+import random
+import unittest
+
+import torch
+from transformers import AutoConfig, AutoTokenizer
+
+from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
+from sglang.test.test_utils import CustomTestCase, is_in_ci
+
+MODELS = [
+    ("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
+    ("BAAI/bge-reranker-v2-m3", 1, 1e-2),
+]
+ATTENTION_BACKENDS = ["torch_native", "triton"]
+
+TORCH_DTYPES = [torch.float32]
+
+
+class TestCrossEncoderModels(CustomTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        mp.set_start_method("spawn", force=True)
+
+    def assert_close_prefill_logits(
+        self,
+        prompts,
+        model_path,
+        tp_size,
+        torch_dtype,
+        score_tolerance,
+        attention_backend,
+    ) -> None:
+        with HFRunner(
+            model_path,
+            torch_dtype=torch_dtype,
+            model_type="cross_encoder",
+        ) as hf_runner:
+            hf_scores = hf_runner.forward(prompts).scores
+
+        with SRTRunner(
+            model_path,
+            tp_size=tp_size,
+            torch_dtype=torch_dtype,
+            model_type="cross_encoder",
+            attention_backend=attention_backend,
+            chunked_prefill_size=-1,
+            disable_radix_cache=True,
+        ) as srt_runner:
+            srt_scores = srt_runner.forward(prompts).scores
+
+        for i in range(len(srt_scores)):
+            score_difference = abs(hf_scores[i] - srt_scores[i])
+
+            assert (
+                score_difference < score_tolerance
+            ), "cross-encoder scores are not all close"
+
+    def preprocess_prompts(self, prompt):
+        processed_prompts = []
+        query = prompt["query"]
+        documents = prompt["documents"]
+        for document in documents:
+            processed_prompts.append([query, document])
+
+        return processed_prompts
+
+    def test_prefill_logits(self):
+        models_to_test = MODELS
+
+        if is_in_ci():
+            models_to_test = [random.choice(MODELS)]
+
+        for model, tp_size, score_tolerance in models_to_test:
+            for attention_backend in ATTENTION_BACKENDS:
+                for query_docs in TEST_RERANK_QUERY_DOCS:
+                    prompts = self.preprocess_prompts(query_docs)
+                    for torch_dtype in TORCH_DTYPES:
+                        self.assert_close_prefill_logits(
+                            prompts,
+                            model,
+                            tp_size,
+                            torch_dtype,
+                            score_tolerance,
+                            attention_backend,
+                        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 8cdde63e2..a2e39b9aa 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -19,6 +19,8 @@ suites = {
         TestFile("models/lora/test_lora_cuda_graph.py", 250),
         TestFile("models/test_embedding_models.py", 73),
         # TestFile("models/test_clip_models.py", 52),
+        TestFile("models/test_encoder_embedding_models.py", 100),
+        TestFile("models/test_cross_encoder_models.py", 100),
         TestFile("models/test_compressed_tensors_models.py", 42),
         TestFile("models/test_generation_models.py", 103),
         # TestFile("models/test_gme_qwen_models.py", 45),
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index d10b953c0..4913eb38c 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -17,7 +17,9 @@ import requests
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_process_tree
+from sglang.test.runners import TEST_RERANK_QUERY_DOCS
 from sglang.test.test_utils import (
+    DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
     DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -699,6 +701,77 @@ class TestOpenAIEmbedding(CustomTestCase):
         self.assertEqual(cm.exception.status_code, 400)
 
 
+class TestOpenAIV1Rerank(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.score_tolerance = 1e-2
+
+        # Configure rerank-specific args (rerank runs in embedding mode)
+        other_args = [
+            "--is-embedding",
+            "--enable-metrics",
+            "--disable-radix-cache",
+            "--chunked-prefill-size",
+            "-1",
+            "--attention-backend",
+            "torch_native",
+        ]
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=other_args,
+        )
+        cls.base_url += "/v1/rerank"
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def run_rerank(self, query, docs):
+        response = requests.post(
+            self.base_url,
+            headers={
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json",
+            },
+            json={"query": query, "documents": docs},
+        )
+
+        return response.json()
+
+    def test_rerank_single(self):
+        """Test a rerank request with a single document."""
+        query = TEST_RERANK_QUERY_DOCS[0]["query"]
+        docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
+
+        response = self.run_rerank(query, docs)
+
+        self.assertEqual(len(response), 1)
+        self.assertTrue(isinstance(response[0]["score"], float))
+        self.assertTrue(isinstance(response[0]["document"], str))
+        self.assertTrue(isinstance(response[0]["index"], int))
+
+    def test_rerank_batch(self):
+        """Test a rerank request with a batch of documents."""
+        query = TEST_RERANK_QUERY_DOCS[1]["query"]
+        docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
+
+        response = self.run_rerank(query, docs)
+
+        self.assertEqual(len(response), 2)
+        self.assertTrue(isinstance(response[0]["score"], float))
+        self.assertTrue(isinstance(response[1]["score"], float))
+        self.assertTrue(isinstance(response[0]["document"], str))
+        self.assertTrue(isinstance(response[1]["document"], str))
+        self.assertTrue(isinstance(response[0]["index"], int))
+        self.assertTrue(isinstance(response[1]["index"], int))
+
+
 class TestOpenAIServerIgnoreEOS(CustomTestCase):
     @classmethod
     def setUpClass(cls):
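[Editor's note — usage sketch, not part of the patch] End-to-end illustration of the new route, assuming a server launched locally on the default port with a cross-encoder checkpoint in embedding mode:

    # python -m sglang.launch_server --model-path BAAI/bge-reranker-v2-m3 \
    #     --is-embedding --disable-radix-cache --chunked-prefill-size -1 \
    #     --attention-backend torch_native
    import requests

    resp = requests.post(
        "http://localhost:30000/v1/rerank",
        json={
            "query": "How many people live in Berlin?",
            "documents": [
                "Berlin had a population of 3,520,031 registered inhabitants.",
                "Berlin is well known for its museums.",
            ],
        },
    )
    # v1_rerank_response returns RerankResponse items sorted by score, highest first.
    for item in resp.json():
        print(item["index"], item["score"], item["document"])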