From 164302c7dfe3fd3e89f50437ce70af5caf1e676d Mon Sep 17 00:00:00 2001 From: Christian Bahls Date: Wed, 22 Oct 2025 22:46:16 +0200 Subject: [PATCH] Implement BGE-M3 Sparse Embeddings in SGLang (#10869) Co-authored-by: Christian Bahls Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/sglang/srt/environ.py | 3 + python/sglang/srt/layers/sparse_pooler.py | 98 +++++++++++++++++++ python/sglang/srt/managers/io_struct.py | 2 +- .../scheduler_output_processor_mixin.py | 17 +++- python/sglang/srt/model_loader/loader.py | 18 +++- python/sglang/srt/models/roberta.py | 57 ++++++++++- 6 files changed, 187 insertions(+), 8 deletions(-) create mode 100644 python/sglang/srt/layers/sparse_pooler.py diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 153e147cd..0f76743b3 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -237,6 +237,9 @@ class Envs: SGLANG_KT_AMX_METHOD = EnvStr(None) SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None) + # Sparse Embeddings + SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None) + # fmt: on diff --git a/python/sglang/srt/layers/sparse_pooler.py b/python/sglang/srt/layers/sparse_pooler.py new file mode 100644 index 000000000..331b23c94 --- /dev/null +++ b/python/sglang/srt/layers/sparse_pooler.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from sglang.srt.model_executor.model_runner import ForwardBatch + + +@dataclass +class SparseEmbeddingOutput: + embeddings: torch.Tensor # [batch_size, vocab_size] + + +class SparsePooler(nn.Module): + """A layer that pools hidden states into sparse vocabulary-space embeddings. + + This layer does the following: + 1. Applies a linear transformation + ReLU to get token-level weights + 2. Maps these weights to vocabulary positions using token IDs + 3. Aggregates weights for repeated tokens using max pooling + 4. Returns sparse embeddings in vocabulary space + + Attributes: + config: Model configuration containing vocab_size and hidden_size + sparse_linear: Linear layer for computing token weights + vocab_size: Size of vocabulary for output embeddings + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + + # Validate required attributes + if not hasattr(config, "vocab_size"): + raise AttributeError( + f"Config {type(config)} missing required 'vocab_size' attribute" + ) + if not hasattr(config, "hidden_size"): + raise AttributeError( + f"Config {type(config)} missing required 'hidden_size' attribute" + ) + + self.vocab_size = config.vocab_size + self.sparse_linear = nn.Linear(config.hidden_size, 1) + self._weights_loaded = False + + def forward( + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + ) -> SparseEmbeddingOutput: + """ + Forward pass for sparse pooling. + + Args: + hidden_states: Packed sequence hidden states [total_tokens, hidden_size] + forward_batch: Batch information with sequence lengths and input_ids + + Returns: + SparseEmbeddingOutput with embeddings of shape [batch_size, vocab_size] + """ + if not self._weights_loaded: + raise ValueError( + "Sparse pooling weights not loaded. 
Call load_weights() first" + ) + + # Apply sparse linear + ReLU to get token weights + token_weights = F.relu(self.sparse_linear(hidden_states)).squeeze( + -1 + ) # [total_tokens] + + # Create batch indices for packed sequences + batch_indices = torch.repeat_interleave( + torch.arange( + len(forward_batch.extend_seq_lens), device=hidden_states.device + ), + forward_batch.extend_seq_lens, + ) + + # Initialize sparse embedding output + sparse_embedding = torch.zeros( + len(forward_batch.extend_seq_lens), + self.vocab_size, + dtype=token_weights.dtype, + device=token_weights.device, + ) + + # Map to vocabulary space using scatter_reduce with amax + flat_indices = batch_indices * self.vocab_size + forward_batch.input_ids + sparse_embedding.view(-1).scatter_reduce_( + 0, flat_indices, token_weights, reduce="amax" + ) + + return SparseEmbeddingOutput(embeddings=sparse_embedding) + + def load_weights(self, state_dict: dict): + """Load weights from state dict (called by the model).""" + self.sparse_linear.load_state_dict(state_dict) + self._weights_loaded = True diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index cd67d4dc3..849204aad 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -961,7 +961,7 @@ class BatchEmbeddingOutput(BaseBatchReq): # The finish reason finished_reasons: List[BaseFinishReason] # The output embedding - embeddings: List[List[float]] + embeddings: Union[List[List[float]], List[Dict[int, float]]] # Token counts prompt_tokens: List[int] cached_tokens: List[int] diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 02b62c0e8..f63fa8179 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.environ import envs from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, @@ -175,7 +176,21 @@ class SchedulerOutputProcessorMixin: logprob_pt += num_input_logprobs else: # embedding or reward model - embeddings = result.embeddings.tolist() + is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set() + + embeddings = result.embeddings + + if is_sparse: + batch_ids, token_ids = embeddings.indices() + values = embeddings.values() + + embeddings = [{} for _ in range(embeddings.size(0))] + for i in range(batch_ids.shape[0]): + embeddings[batch_ids[i].item()][token_ids[i].item()] = values[ + i + ].item() + else: + embeddings = embeddings.tolist() # Check finish conditions for i, req in enumerate(batch.reqs): diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index c8ef20fe3..6134f24ba 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -77,6 +77,7 @@ from sglang.srt.model_loader.utils import ( DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = ( 0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration ) +from sglang.srt.environ import envs from sglang.srt.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, @@ -244,10 +245,19 @@ def _initialize_model( quant_config = _get_quantization_config( model_config, load_config, packed_modules_mapping ) 
- return model_class( - config=model_config.hf_config, - quant_config=quant_config, - ) + + # Build kwargs conditionally + kwargs = { + "config": model_config.hf_config, + "quant_config": quant_config, + } + + # Only add sparse head kwargs if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set() + if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set(): + kwargs["sparse_head"] = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.value + kwargs["model_path"] = model_config.model_path + + return model_class(**kwargs) class BaseModelLoader(ABC): diff --git a/python/sglang/srt/models/roberta.py b/python/sglang/srt/models/roberta.py index 9fad5cfa3..c81590320 100644 --- a/python/sglang/srt/models/roberta.py +++ b/python/sglang/srt/models/roberta.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import os from typing import Iterable, Optional, Tuple import torch @@ -7,10 +8,12 @@ from torch import nn from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.sparse_pooler import SparsePooler from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.bert import BertEncoder +from sglang.srt.utils.hf_transformers_utils import download_from_hf RobertaConfig = None @@ -205,12 +208,29 @@ class XLMRobertaModel(nn.Module): config: RobertaConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + sparse_head: Optional[str] = None, + model_path: Optional[str] = None, ): super().__init__() self.roberta = XLMRobertaBaseModel( config=config, quant_config=quant_config, prefix=prefix ) - self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True) + if sparse_head is not None: + self._is_sparse = True + self._model_path = model_path + self._sparse_head = sparse_head + self.pooler = SparsePooler(config=config) + # Zero out special tokens + self._special_tokens = [ + config.bos_token_id, + config.eos_token_id, + config.pad_token_id, + # self.config.unk_token_id # not available in the XLMRobertaConfig + ] + self._special_tokens = [t for t in self._special_tokens if t is not None] + else: + self._is_sparse = False + self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True) def forward( self, @@ -223,11 +243,44 @@ class XLMRobertaModel(nn.Module): hidden_states = self.roberta( input_ids, positions, forward_batch, input_embeds, get_embedding ) - return self.pooler(hidden_states, forward_batch) + embeddings = self.pooler(hidden_states, forward_batch) + + if self._is_sparse: + for token_id in self._special_tokens: + embeddings.embeddings[:, token_id] = 0.0 + embeddings.embeddings = embeddings.embeddings.to_sparse() + + return embeddings def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.roberta.load_weights(weights) + if self._is_sparse: + sparse_dict = XLMRobertaModel._load_sparse_linear( + self._model_path, self._sparse_head + ) + self.pooler.load_weights(sparse_dict) + + @staticmethod + def _load_sparse_linear(model_path_or_dir: str, sparse_head: str) -> dict: + """ + Load sparse_head from local dir or HF Hub. + Returns a state_dict suitable for nn.Linear.load_state_dict(). 
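+
+        The returned dict feeds SparsePooler.load_weights() and is expected to
+        hold the "weight"/"bias" tensors of its single-output nn.Linear.
+        Illustrative example (assuming the BGE-M3 layout, where the sparse
+        head is stored as "sparse_linear.pt" in the checkpoint; adjust the
+        file name if your checkpoint differs):
+
+            sd = XLMRobertaModel._load_sparse_linear(
+                "BAAI/bge-m3", "sparse_linear.pt"
+            )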
+ """ + if os.path.isdir(model_path_or_dir): + path = os.path.join(model_path_or_dir, sparse_head) + if not os.path.exists(path): + raise FileNotFoundError( + f"'{sparse_head}' not found in {model_path_or_dir}" + ) + else: + # remote → use SGLang HF utility + local_dir = download_from_hf(model_path_or_dir, allow_patterns=sparse_head) + path = os.path.join(local_dir, sparse_head) + + state_dict = torch.load(path) + return state_dict + class XLMRobertaForSequenceClassification(nn.Module): def __init__(