Feat/support rerank (#6058)
This commit is contained in:
@@ -550,6 +550,11 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
||||
or "Qwen2ForRewardModel" in model_architectures
|
||||
or "Qwen2ForSequenceClassification" in model_architectures
|
||||
or "CLIPModel" in model_architectures
|
||||
or "BertModel" in model_architectures
|
||||
or "Contriever" in model_architectures
|
||||
or "BertForSequenceClassification" in model_architectures
|
||||
or "XLMRobertaModel" in model_architectures
|
||||
or "XLMRobertaForSequenceClassification" in model_architectures
|
||||
):
|
||||
return False
|
||||
else:
|
||||
|
||||
@@ -327,6 +327,20 @@ class Engine(EngineBase):
|
||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||
return await generator.__anext__()
|
||||
|
||||
def rerank(
    self,
    prompt: List[List[str]],
) -> Dict:
    """Score (query, document) pairs with a cross-encoder model.

    Each inner list is a ``[query, document]`` pair. The arguments of this
    function are otherwise the same as
    `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
    Please refer to `EmbeddingReqInput` for the documentation.
    """
    obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
    loop = asyncio.get_event_loop()
    generator = self.tokenizer_manager.generate_request(obj, None)
    # A non-streaming embedding-style request yields exactly one result.
    return loop.run_until_complete(generator.__anext__())
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the engine"""
|
||||
kill_process_tree(os.getpid(), include_parent=False)
|
||||
|
||||
@@ -67,6 +67,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
V1RerankReqInput,
|
||||
VertexGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -79,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
|
||||
v1_delete_file,
|
||||
v1_embeddings,
|
||||
v1_files_create,
|
||||
v1_rerank,
|
||||
v1_retrieve_batch,
|
||||
v1_retrieve_file,
|
||||
v1_retrieve_file_content,
|
||||
@@ -328,6 +330,15 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/v1/rerank", methods=["POST", "PUT"])
|
||||
async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
|
||||
try:
|
||||
ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/flush_cache", methods=["GET", "POST"])
|
||||
async def flush_cache():
|
||||
"""Flush the radix cache."""
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import (
|
||||
@@ -29,6 +30,7 @@ from sglang.srt.distributed import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.utils import is_cuda, set_weight_attrs
|
||||
from sglang.utils import resolve_obj_by_qualname
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
@@ -165,6 +167,23 @@ def get_act_fn(
|
||||
return act_fn
|
||||
|
||||
|
||||
def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation applied to a cross-encoder's classifier logits.

    Honors the sentence-transformers `sbert_ce_default_activation_function`
    config field when present; otherwise falls back to the identity
    function (e.g. for bge-reranker style models).
    """
    function_name = getattr(config, "sbert_ce_default_activation_function", None)

    if function_name is None:
        # adapt bge-reranker
        return nn.Identity()

    assert function_name.startswith("torch.nn.modules."), (
        "Loading of activation functions is restricted to "
        "torch.nn.modules for security reasons"
    )
    return resolve_obj_by_qualname(function_name)()
|
||||
|
||||
|
||||
if not _is_cuda:
|
||||
logger.info(
|
||||
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||
|
||||
@@ -3,10 +3,13 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.layers.activation import get_cross_encoder_activation_function
|
||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||
|
||||
|
||||
@@ -54,3 +57,56 @@ class Pooler(nn.Module):
|
||||
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
|
||||
|
||||
return EmbeddingPoolerOutput(embeddings=pooled_data)
|
||||
|
||||
|
||||
class CrossEncodingPooler(nn.Module):
    """A layer that pools sentence-pair scores from packed hidden states.

    This layer does the following:
    1. Splits the packed hidden states back into per-request segments.
    2. Applies the optional pooler (per segment) and the classifier.
    3. Returns the activated scores wrapped in `EmbeddingPoolerOutput`.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        classifier: nn.Module,
        pooler: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.classifier = classifier
        self.pooler = pooler
        self.default_activation_function = get_cross_encoder_activation_function(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> EmbeddingPoolerOutput:
        """Pools sentence pair scores from the hidden_states."""
        seq_lens = forward_batch.extend_seq_lens

        per_request = []
        start = 0
        for seq_len in seq_lens:
            segment = hidden_states[start : start + seq_len]

            if self.pooler is None:
                # No pooler: apply the classifier directly to the segment.
                per_request.append(self.classifier(segment))
            else:
                per_request.append(self.pooler(segment, forward_batch))

            start += seq_len

        pooled_output = torch.stack(per_request)

        if self.pooler is not None:
            # apply classifier once on the full batch if possible
            pooled_output = self.classifier(pooled_output)

        scores = self.default_activation_function(pooled_output).squeeze(-1)

        return EmbeddingPoolerOutput(embeddings=scores)
|
||||
|
||||
@@ -481,7 +481,7 @@ class TokenizedGenerateReqInput:
|
||||
@dataclass
|
||||
class EmbeddingReqInput:
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
text: Optional[Union[List[str], str]] = None
|
||||
text: Optional[Union[List[List[str]], List[str], str]] = None
|
||||
# The image input. It can be an image instance, file name, URL, or base64 encoded string.
|
||||
# Can be formatted as:
|
||||
# - Single image for a single request
|
||||
@@ -505,6 +505,8 @@ class EmbeddingReqInput:
|
||||
log_metrics: bool = True
|
||||
# The modalities of the image data [image, multi-images, video]
|
||||
modalities: Optional[List[str]] = None
|
||||
# For cross-encoder requests
|
||||
is_cross_encoder_request: bool = False
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
||||
@@ -564,6 +566,16 @@ class EmbeddingReqInput:
|
||||
return self.rid
|
||||
|
||||
def __getitem__(self, i):
|
||||
if self.is_cross_encoder_request:
|
||||
return EmbeddingReqInput(
|
||||
text=[self.text[i]] if self.text is not None else None,
|
||||
input_ids=None,
|
||||
image_data=None,
|
||||
sampling_params=self.sampling_params[i],
|
||||
rid=self.rid[i],
|
||||
is_cross_encoder_request=True,
|
||||
)
|
||||
|
||||
return EmbeddingReqInput(
|
||||
text=self.text[i] if self.text is not None else None,
|
||||
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
||||
@@ -583,6 +595,8 @@ class TokenizedEmbeddingReqInput:
|
||||
input_ids: List[int]
|
||||
# The image inputs
|
||||
image_inputs: dict
|
||||
# The token type ids
|
||||
token_type_ids: List[int]
|
||||
# Dummy sampling params for compatibility
|
||||
sampling_params: SamplingParams
|
||||
|
||||
@@ -847,6 +861,12 @@ class SetInternalStateReq:
|
||||
server_args: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
class V1RerankReqInput:
    """Input payload for the /v1/rerank endpoint."""

    # The query text that every document is scored against.
    query: str
    # The candidate documents to rerank; each is paired with `query`
    # to form one cross-encoder input.
    documents: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetInternalStateReqOutput:
|
||||
updated: bool
|
||||
|
||||
@@ -445,6 +445,7 @@ class Req:
|
||||
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
input_embeds: Optional[List[List[float]]] = None,
|
||||
token_type_ids: List[int] = None,
|
||||
session_id: Optional[str] = None,
|
||||
custom_logit_processor: Optional[str] = None,
|
||||
return_hidden_states: bool = False,
|
||||
@@ -470,6 +471,9 @@ class Req:
|
||||
self.session_id = session_id
|
||||
self.input_embeds = input_embeds
|
||||
|
||||
# For cross-encoder model
|
||||
self.token_type_ids = token_type_ids
|
||||
|
||||
# Sampling info
|
||||
if isinstance(sampling_params.custom_params, dict):
|
||||
sampling_params = copy.copy(sampling_params)
|
||||
@@ -841,6 +845,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# Batched arguments to model runner
|
||||
input_ids: torch.Tensor = None # shape: [b], int64
|
||||
input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
|
||||
token_type_ids: torch.Tensor = None # shape: [b], int64
|
||||
req_pool_indices: torch.Tensor = None # shape: [b], int64
|
||||
seq_lens: torch.Tensor = None # shape: [b], int64
|
||||
# The output locations of the KV cache
|
||||
@@ -1142,6 +1147,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||
extend_lens = [r.extend_input_len for r in reqs]
|
||||
|
||||
token_type_ids = [
|
||||
r.token_type_ids for r in reqs if r.token_type_ids is not None
|
||||
]
|
||||
|
||||
req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
@@ -1154,6 +1163,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
prefix_lens_tensor = torch.tensor(
|
||||
prefix_lens, dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
token_type_ids_tensor = None
|
||||
if len(token_type_ids) > 0:
|
||||
token_type_ids_tensor = torch.tensor(
|
||||
sum(token_type_ids, []), dtype=torch.int64
|
||||
).to(self.device, non_blocking=True)
|
||||
|
||||
extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
|
||||
|
||||
# Copy prefix and do some basic check
|
||||
@@ -1269,6 +1285,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.multimodal_inputs = multimodal_inputs
|
||||
self.token_type_ids = token_type_ids_tensor
|
||||
self.seq_lens_sum = sum(seq_lens)
|
||||
|
||||
if self.return_logprob:
|
||||
@@ -1714,6 +1731,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
lora_paths=[req.lora_path for req in self.reqs],
|
||||
sampling_info=self.sampling_info,
|
||||
input_embeds=self.input_embeds,
|
||||
token_type_ids=self.token_type_ids,
|
||||
spec_algorithm=self.spec_algorithm,
|
||||
spec_info=self.spec_info,
|
||||
capture_hidden_mode=(
|
||||
@@ -1807,6 +1825,9 @@ class ModelWorkerBatch:
|
||||
# The input Embeds
|
||||
input_embeds: Optional[torch.tensor] = None
|
||||
|
||||
# For cross-encoder model
|
||||
token_type_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||
|
||||
@@ -1150,6 +1150,7 @@ class Scheduler(
|
||||
recv_req.input_text,
|
||||
recv_req.input_ids,
|
||||
recv_req.sampling_params,
|
||||
token_type_ids=recv_req.token_type_ids,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
|
||||
@@ -459,6 +459,10 @@ class TokenizerManager:
|
||||
# Tokenize
|
||||
input_embeds = None
|
||||
input_text = obj.text
|
||||
token_type_ids = None
|
||||
is_cross_encoder_request = (
|
||||
isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
|
||||
)
|
||||
if obj.input_embeds is not None:
|
||||
if not self.server_args.disable_radix_cache:
|
||||
raise ValueError(
|
||||
@@ -477,7 +481,14 @@ class TokenizerManager:
|
||||
"accept text prompts. Please provide input_ids or re-initialize "
|
||||
"the engine with skip_tokenizer_init=False."
|
||||
)
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
encoded = self.tokenizer(
|
||||
input_text, return_token_type_ids=is_cross_encoder_request
|
||||
)
|
||||
|
||||
input_ids = encoded["input_ids"]
|
||||
if is_cross_encoder_request:
|
||||
input_ids = encoded["input_ids"][0]
|
||||
token_type_ids = encoded.get("token_type_ids", [None])[0]
|
||||
|
||||
if self.mm_processor and obj.contains_mm_input():
|
||||
image_inputs = await self.mm_processor.process_mm_data_async(
|
||||
@@ -493,7 +504,7 @@ class TokenizerManager:
|
||||
|
||||
self._validate_token_len(obj, input_ids)
|
||||
return self._create_tokenized_object(
|
||||
obj, input_text, input_ids, input_embeds, image_inputs
|
||||
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
|
||||
)
|
||||
|
||||
def _validate_token_len(
|
||||
@@ -532,6 +543,7 @@ class TokenizerManager:
|
||||
input_ids: List[int],
|
||||
input_embeds: Optional[Union[List[float], None]] = None,
|
||||
image_inputs: Optional[Dict] = None,
|
||||
token_type_ids: Optional[List[int]] = None,
|
||||
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
|
||||
"""Create a tokenized request object from common parameters."""
|
||||
|
||||
@@ -592,6 +604,7 @@ class TokenizerManager:
|
||||
input_text,
|
||||
input_ids,
|
||||
image_inputs,
|
||||
token_type_ids,
|
||||
sampling_params,
|
||||
)
|
||||
|
||||
|
||||
@@ -224,6 +224,9 @@ class ForwardBatch:
|
||||
# For input embeddings
|
||||
input_embeds: Optional[torch.tensor] = None
|
||||
|
||||
# For cross-encoder model
|
||||
token_type_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
|
||||
@@ -300,6 +303,7 @@ class ForwardBatch:
|
||||
spec_info=batch.spec_info,
|
||||
capture_hidden_mode=batch.capture_hidden_mode,
|
||||
input_embeds=batch.input_embeds,
|
||||
token_type_ids=batch.token_type_ids,
|
||||
tbo_split_seq_index=batch.tbo_split_seq_index,
|
||||
)
|
||||
device = model_runner.device
|
||||
@@ -356,8 +360,8 @@ class ForwardBatch:
|
||||
ret.extend_prefix_lens = torch.tensor(
|
||||
batch.extend_prefix_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
if support_triton(model_runner.server_args.attention_backend):
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens,
|
||||
ret.extend_seq_lens,
|
||||
|
||||
@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
BertConfig = None
|
||||
|
||||
@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
input_shape = input_ids.size()
|
||||
|
||||
@@ -58,11 +60,14 @@ class BertEmbedding(nn.Module):
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
# Position embeddings.
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
position_embeddings = self.position_embeddings(positions)
|
||||
|
||||
token_type_ids = torch.zeros(
|
||||
input_shape, dtype=torch.long, device=inputs_embeds.device
|
||||
)
|
||||
token_type_ids = forward_batch.token_type_ids
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(
|
||||
input_shape, dtype=torch.long, device=inputs_embeds.device
|
||||
)
|
||||
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
    """Classic BERT pooler: dense + tanh on the first token's hidden state."""

    def __init__(self, config: BertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        # Take the hidden state of the first token (assumed [CLS]-like).
        first_token_state = hidden_states[0, :]
        return self.activation(self.dense(first_token_state))
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attention = BertAttention(
|
||||
hidden_size=config.hidden_size,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
|
||||
attn_output = self.attention(hidden_states, forward_batch)
|
||||
intermediate_output = self.intermediate(attn_output)
|
||||
output = self.output(intermediate_output, attn_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -326,16 +353,23 @@ class BertModel(nn.Module):
|
||||
*,
|
||||
config: BertConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
use_bert_pooler: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.use_bert_pooler = use_bert_pooler
|
||||
self.config = config
|
||||
self.embeddings = BertEmbedding(config)
|
||||
self.encoder = BertEncoder(
|
||||
config=config, quant_config=quant_config, prefix=f"encoder"
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("encoder", prefix),
|
||||
)
|
||||
self.pooler = (
|
||||
BertPooler(config)
|
||||
if self.use_bert_pooler
|
||||
else Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
# self.pooler = BertPooler(config)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
@@ -351,11 +385,16 @@ class BertModel(nn.Module):
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
if not self.use_bert_pooler:
|
||||
hidden_states = self.pooler(hidden_states, forward_batch)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
@@ -368,7 +407,7 @@ class BertModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
name = name.replace("self", "self_attn")
|
||||
if "pooler" in name:
|
||||
if not self.use_bert_pooler and "pooler" in name:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
|
||||
@@ -395,4 +434,65 @@ class Contriever(BertModel):
|
||||
pass
|
||||
|
||||
|
||||
EntryClass = [BertModel, Contriever]
|
||||
class BertForSequenceClassification(nn.Module):
    """BERT cross-encoder for sequence classification / reranking.

    Wraps a `BertModel` (with its BertPooler enabled) and a linear
    classification head, combined through `CrossEncodingPooler` so each
    packed request in the batch is scored independently.
    """

    def __init__(
        self,
        *,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.num_labels = config.num_labels
        self.bert = BertModel(
            config=config,
            quant_config=quant_config,
            use_bert_pooler=True,
            prefix=add_prefix("bert", prefix),
        )
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Route `bert.*` weights to the backbone; load the classifier here."""
        self_weights = []

        def weight_filter():
            # Lazily split the stream: backbone weights are yielded with the
            # "bert." prefix stripped, everything else is kept for this module.
            for name, weight in weights:
                if name.startswith("bert."):
                    yield (name[len("bert.") :], weight)
                else:
                    self_weights.append((name, weight))

        self.bert.load_weights(weight_filter())

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in self_weights:
            if name.startswith("classifier"):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        # This model only produces scores, never generation logits.
        assert get_embedding, "BertForSequenceClassification is only used for rerank"

        hidden_states = self.bert(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
            get_embedding=get_embedding,
        )
        return self.pooler(hidden_states, forward_batch)
|
||||
|
||||
|
||||
EntryClass = [BertModel, Contriever, BertForSequenceClassification]
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -16,6 +16,23 @@ from sglang.srt.models.bert import BertEncoder
|
||||
RobertaConfig = None
|
||||
|
||||
|
||||
# Adapted from transformers
|
||||
# Adapted from transformers
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config: RobertaConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        cls_token = features[0, :]  # take <s> token (equiv. to [CLS])
        hidden = torch.tanh(self.dense(cls_token))
        return self.out_proj(hidden)
|
||||
|
||||
|
||||
class RobertaEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, config: RobertaConfig):
|
||||
@@ -51,8 +68,7 @@ class RobertaEmbedding(nn.Module):
|
||||
input_ids: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
inputs_embeds=None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
input_shape = input_ids.size()
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
@@ -82,6 +98,8 @@ class RobertaEmbedding(nn.Module):
|
||||
|
||||
# Position embeddings.
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
token_type_ids = forward_batch.token_type_ids
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(
|
||||
input_shape, dtype=torch.long, device=inputs_embeds.device
|
||||
@@ -93,20 +111,25 @@ class RobertaEmbedding(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class XLMRobertaModel(nn.Module):
|
||||
class XLMRobertaBaseModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: RobertaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
add_pooling_layer: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.embeddings = RobertaEmbedding(config)
|
||||
self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
|
||||
self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
|
||||
self.pooler = (
|
||||
Pooler(pooling_type=PoolingType.CLS, normalize=True)
|
||||
if add_pooling_layer
|
||||
else None
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
@@ -124,11 +147,12 @@ class XLMRobertaModel(nn.Module):
|
||||
input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
seq_lens=forward_batch.seq_lens,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
|
||||
pooler_out = self.pooler(hidden_states, forward_batch)
|
||||
return pooler_out
|
||||
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
@@ -141,7 +165,7 @@ class XLMRobertaModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
name = name.replace("self", "self_attn")
|
||||
if "pooler" in name:
|
||||
if self.pooler is None and "pooler" in name:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
|
||||
@@ -175,4 +199,88 @@ def create_position_ids_from_input_ids(
|
||||
return incremental_indices.long() + padding_idx
|
||||
|
||||
|
||||
EntryClass = [XLMRobertaModel]
|
||||
class XLMRobertaModel(nn.Module):
    """Plain XLM-RoBERTa embedding model: base encoder + CLS pooling."""

    def __init__(
        self,
        *,
        config: RobertaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.roberta = XLMRobertaBaseModel(
            config=config, quant_config=quant_config, prefix=prefix
        )
        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        encoder_out = self.roberta(
            input_ids, positions, forward_batch, input_embeds, get_embedding
        )
        return self.pooler(encoder_out, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # All checkpoint weights belong to the base model; delegate directly.
        self.roberta.load_weights(weights)
|
||||
|
||||
|
||||
class XLMRobertaForSequenceClassification(nn.Module):
    """XLM-RoBERTa cross-encoder scoring head for reranking."""

    def __init__(
        self,
        *,
        config: RobertaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.roberta = XLMRobertaBaseModel(
            config=config, quant_config=quant_config, prefix=prefix
        )
        self.classifier = RobertaClassificationHead(config)
        self.pooler = CrossEncodingPooler(config, self.classifier, self.roberta.pooler)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = True,
    ) -> torch.Tensor:
        assert (
            get_embedding
        ), "XLMRobertaForSequenceClassification is only used for rerank"

        encoder_out = self.roberta(
            input_ids, positions, forward_batch, input_embeds, get_embedding
        )
        return self.pooler(encoder_out, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Split incoming weights between the backbone and the classifier head."""
        head_weights = []

        def backbone_stream():
            # Yield backbone weights lazily (prefix stripped); stash the
            # remainder for the classification head below.
            for name, weight in weights:
                if name.startswith("roberta."):
                    yield (name[len("roberta.") :], weight)
                else:
                    head_weights.append((name, weight))

        self.roberta.load_weights(backbone_stream())

        params = dict(self.named_parameters())

        for name, loaded_weight in head_weights:
            if name.startswith("classifier"):
                param = params[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = [XLMRobertaModel, XLMRobertaForSequenceClassification]
|
||||
|
||||
@@ -41,7 +41,11 @@ from sglang.srt.conversation import (
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
V1RerankReqInput,
|
||||
)
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
BatchRequest,
|
||||
BatchResponse,
|
||||
@@ -69,6 +73,7 @@ from sglang.srt.openai_api.protocol import (
|
||||
FunctionResponse,
|
||||
LogProbs,
|
||||
MultimodalEmbeddingInput,
|
||||
RerankResponse,
|
||||
ScoringRequest,
|
||||
ScoringResponse,
|
||||
ToolCall,
|
||||
@@ -2020,6 +2025,64 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
|
||||
return response
|
||||
|
||||
|
||||
def v1_rerank_request(obj: V1RerankReqInput):
    """Validate a rerank request and adapt it to a cross-encoder embedding request.

    Raises:
        ValueError: if `query` is missing or `documents` is missing/empty.
    """
    if obj.query is None:
        raise ValueError("query is required")
    if not obj.documents:
        raise ValueError("documents is required")

    # Each [query, document] pair is scored jointly by the cross-encoder.
    pairs = [[obj.query, doc] for doc in obj.documents]

    return EmbeddingReqInput(
        text=pairs,
        is_cross_encoder_request=True,
    )
|
||||
|
||||
|
||||
def v1_rerank_response(ret, obj: V1RerankReqInput):
    """Build the rerank response list, sorted by score (highest first).

    `ret` is the per-pair engine output, parallel to `obj.documents`; each
    item carries the score under "embedding" plus request "meta_info".
    """
    response = [
        RerankResponse(
            score=item["embedding"],
            document=obj.documents[idx],
            index=idx,
            meta_info=item["meta_info"],
        )
        for idx, item in enumerate(ret)
    ]

    # Most relevant documents first.
    response.sort(key=lambda x: x.score, reverse=True)

    return response
|
||||
|
||||
|
||||
async def v1_rerank(tokenizer_manager, obj: V1RerankReqInput, raw_request: Request):
    """Serve a /v1/rerank request end to end.

    Adapts the request to a cross-encoder embedding request, runs it through
    the tokenizer manager, and formats the scored documents as a response.
    """
    adapted_request = v1_rerank_request(obj)

    try:
        generator = tokenizer_manager.generate_request(adapted_request, raw_request)
        ret = await generator.__anext__()
    except ValueError as e:
        return create_error_response(str(e))

    # A single-pair request returns a bare dict; normalize to a list.
    if not isinstance(ret, list):
        ret = [ret]

    return v1_rerank_response(ret, obj)
|
||||
|
||||
|
||||
def to_openai_style_logprobs(
|
||||
input_token_logprobs=None,
|
||||
output_token_logprobs=None,
|
||||
|
||||
@@ -539,6 +539,13 @@ class ScoringResponse(BaseModel):
|
||||
object: str = "scoring"
|
||||
|
||||
|
||||
class RerankResponse(BaseModel):
    """A single scored document returned by the /v1/rerank endpoint."""

    # Relevance score produced by the cross-encoder for (query, document).
    score: float
    # The original document text from the request.
    document: str
    # Position of the document in the request's `documents` list.
    index: int
    # Engine metadata for this request; contents depend on the server.
    meta_info: Optional[dict] = None
|
||||
|
||||
|
||||
def exclude_if_none(obj, field_names: List[str]):
    """Serialize *obj* into a dict, dropping None values for the named fields.

    Fields listed in *field_names* (and declared on the model) are omitted
    when their value is None; all other fields are always kept.
    """
    optional_fields = set(field_names) & set(obj.model_fields)
    return {
        key: value
        for key, value in obj
        if value is not None or key not in optional_fields
    }
|
||||
|
||||
@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
|
||||
# the output of gemma-2-2b from SRT is unstable on the commented prompt
|
||||
# "The capital of France is",
|
||||
]
|
||||
# Query/document fixtures for rerank (cross-encoder) tests: a single-document
# case and a multi-document case sharing the same query.
TEST_RERANK_QUERY_DOCS = [
    {
        "query": "How many people live in Berlin?",
        "documents": [
            "Berlin is well known for its museums.",
        ],
    },
    {
        "query": "How many people live in Berlin?",
        "documents": [
            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
            "Berlin is well known for its museums.",
        ],
    },
]
|
||||
|
||||
dirpath = os.path.dirname(__file__)
|
||||
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
|
||||
@@ -241,7 +256,7 @@ class HFRunner:
|
||||
self.model = _get_sentence_transformer_embedding_model(
|
||||
model_path, torch_dtype
|
||||
)
|
||||
elif self.model_type == "reward":
|
||||
elif self.model_type == "reward" or self.model_type == "cross_encoder":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
@@ -303,6 +318,15 @@ class HFRunner:
|
||||
else:
|
||||
logits = self.model.encode(prompts).tolist()
|
||||
out_queue.put(ModelOutput(embed_logits=logits))
|
||||
elif self.model_type == "cross_encoder":
|
||||
inputs = self.tokenizer(
|
||||
prompts, padding=True, return_tensors="pt"
|
||||
).to("cuda")
|
||||
scores = self.model(**inputs).logits
|
||||
scores = scores.squeeze().tolist()
|
||||
if not isinstance(scores, list):
|
||||
scores = [scores]
|
||||
out_queue.put(ModelOutput(scores=scores))
|
||||
|
||||
elif self.model_type == "reward":
|
||||
scores = []
|
||||
@@ -322,7 +346,9 @@ class HFRunner:
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
prompts: Union[
|
||||
List[List[str]], List[str], List[torch.Tensor]
|
||||
] = DEFAULT_PROMPTS,
|
||||
image_data: Optional[List[str]] = None,
|
||||
max_new_tokens: int = 8,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
@@ -526,7 +552,9 @@ class SRTRunner:
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
prompts: Union[
|
||||
List[List[str]], List[str], List[torch.Tensor]
|
||||
] = DEFAULT_PROMPTS,
|
||||
image_data: Optional[List[str]] = None,
|
||||
max_new_tokens: int = 8,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
@@ -552,6 +580,13 @@ class SRTRunner:
|
||||
else:
|
||||
logits = [response["embedding"]]
|
||||
return ModelOutput(embed_logits=logits)
|
||||
# cross encoder model
|
||||
elif self.model_type == "cross_encoder":
|
||||
response = self.engine.rerank(prompts)
|
||||
if not isinstance(response, list):
|
||||
response = [response]
|
||||
scores = [x["embedding"] for x in response]
|
||||
return ModelOutput(scores=scores)
|
||||
# reward model
|
||||
else:
|
||||
response = self.engine.encode(prompts)
|
||||
|
||||
@@ -41,6 +41,8 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
|
||||
|
||||
# Embedding / cross-encoder test models
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST = "cross-encoder/ms-marco-MiniLM-L6-v2"
# MLA test models
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
|
||||
|
||||
@@ -512,3 +512,12 @@ async def async_stream_and_merge(llm, prompt, sampling_params):
|
||||
cleaned_chunk = trim_overlap(final_text, chunk_text)
|
||||
final_text += cleaned_chunk
|
||||
yield cleaned_chunk # yield the non-overlapping portion
|
||||
|
||||
|
||||
def resolve_obj_by_qualname(qualname: str) -> Any:
    """Resolve an object by its fully qualified name.

    Example: ``"os.path.join"`` imports ``os.path`` and returns its
    ``join`` attribute.
    """
    module_path, attribute_name = qualname.rsplit(".", 1)
    target_module = importlib.import_module(module_path)
    return getattr(target_module, attribute_name)
|
||||
|
||||
Reference in New Issue
Block a user