Implement BGE-M3 Sparse Embeddings in SGLang (#10869)
Co-authored-by: Christian Bahls <christian.bahls@planet-ai.de> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -237,6 +237,9 @@ class Envs:
|
||||
SGLANG_KT_AMX_METHOD = EnvStr(None)
|
||||
SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None)
|
||||
|
||||
# Sparse Embeddings
|
||||
SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
|
||||
98
python/sglang/srt/layers/sparse_pooler.py
Normal file
98
python/sglang/srt/layers/sparse_pooler.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseEmbeddingOutput:
|
||||
embeddings: torch.Tensor # [batch_size, vocab_size]
|
||||
|
||||
|
||||
class SparsePooler(nn.Module):
|
||||
"""A layer that pools hidden states into sparse vocabulary-space embeddings.
|
||||
|
||||
This layer does the following:
|
||||
1. Applies a linear transformation + ReLU to get token-level weights
|
||||
2. Maps these weights to vocabulary positions using token IDs
|
||||
3. Aggregates weights for repeated tokens using max pooling
|
||||
4. Returns sparse embeddings in vocabulary space
|
||||
|
||||
Attributes:
|
||||
config: Model configuration containing vocab_size and hidden_size
|
||||
sparse_linear: Linear layer for computing token weights
|
||||
vocab_size: Size of vocabulary for output embeddings
|
||||
"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
|
||||
# Validate required attributes
|
||||
if not hasattr(config, "vocab_size"):
|
||||
raise AttributeError(
|
||||
f"Config {type(config)} missing required 'vocab_size' attribute"
|
||||
)
|
||||
if not hasattr(config, "hidden_size"):
|
||||
raise AttributeError(
|
||||
f"Config {type(config)} missing required 'hidden_size' attribute"
|
||||
)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.sparse_linear = nn.Linear(config.hidden_size, 1)
|
||||
self._weights_loaded = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
||||
) -> SparseEmbeddingOutput:
|
||||
"""
|
||||
Forward pass for sparse pooling.
|
||||
|
||||
Args:
|
||||
hidden_states: Packed sequence hidden states [total_tokens, hidden_size]
|
||||
forward_batch: Batch information with sequence lengths and input_ids
|
||||
|
||||
Returns:
|
||||
SparseEmbeddingOutput with embeddings of shape [batch_size, vocab_size]
|
||||
"""
|
||||
if not self._weights_loaded:
|
||||
raise ValueError(
|
||||
"Sparse pooling weights not loaded. Call load_weights() first"
|
||||
)
|
||||
|
||||
# Apply sparse linear + ReLU to get token weights
|
||||
token_weights = F.relu(self.sparse_linear(hidden_states)).squeeze(
|
||||
-1
|
||||
) # [total_tokens]
|
||||
|
||||
# Create batch indices for packed sequences
|
||||
batch_indices = torch.repeat_interleave(
|
||||
torch.arange(
|
||||
len(forward_batch.extend_seq_lens), device=hidden_states.device
|
||||
),
|
||||
forward_batch.extend_seq_lens,
|
||||
)
|
||||
|
||||
# Initialize sparse embedding output
|
||||
sparse_embedding = torch.zeros(
|
||||
len(forward_batch.extend_seq_lens),
|
||||
self.vocab_size,
|
||||
dtype=token_weights.dtype,
|
||||
device=token_weights.device,
|
||||
)
|
||||
|
||||
# Map to vocabulary space using scatter_reduce with amax
|
||||
flat_indices = batch_indices * self.vocab_size + forward_batch.input_ids
|
||||
sparse_embedding.view(-1).scatter_reduce_(
|
||||
0, flat_indices, token_weights, reduce="amax"
|
||||
)
|
||||
|
||||
return SparseEmbeddingOutput(embeddings=sparse_embedding)
|
||||
|
||||
def load_weights(self, state_dict: dict):
|
||||
"""Load weights from state dict (called by the model)."""
|
||||
self.sparse_linear.load_state_dict(state_dict)
|
||||
self._weights_loaded = True
|
||||
@@ -961,7 +961,7 @@ class BatchEmbeddingOutput(BaseBatchReq):
|
||||
# The finish reason
|
||||
finished_reasons: List[BaseFinishReason]
|
||||
# The output embedding
|
||||
embeddings: List[List[float]]
|
||||
embeddings: Union[List[List[float]], List[Dict[int, float]]]
|
||||
# Token counts
|
||||
prompt_tokens: List[int]
|
||||
cached_tokens: List[int]
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
@@ -175,7 +176,21 @@ class SchedulerOutputProcessorMixin:
|
||||
logprob_pt += num_input_logprobs
|
||||
|
||||
else: # embedding or reward model
|
||||
embeddings = result.embeddings.tolist()
|
||||
is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
|
||||
|
||||
embeddings = result.embeddings
|
||||
|
||||
if is_sparse:
|
||||
batch_ids, token_ids = embeddings.indices()
|
||||
values = embeddings.values()
|
||||
|
||||
embeddings = [{} for _ in range(embeddings.size(0))]
|
||||
for i in range(batch_ids.shape[0]):
|
||||
embeddings[batch_ids[i].item()][token_ids[i].item()] = values[
|
||||
i
|
||||
].item()
|
||||
else:
|
||||
embeddings = embeddings.tolist()
|
||||
|
||||
# Check finish conditions
|
||||
for i, req in enumerate(batch.reqs):
|
||||
|
||||
@@ -77,6 +77,7 @@ from sglang.srt.model_loader.utils import (
|
||||
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
|
||||
0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration
|
||||
)
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
@@ -244,10 +245,19 @@ def _initialize_model(
|
||||
quant_config = _get_quantization_config(
|
||||
model_config, load_config, packed_modules_mapping
|
||||
)
|
||||
return model_class(
|
||||
config=model_config.hf_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Build kwargs conditionally
|
||||
kwargs = {
|
||||
"config": model_config.hf_config,
|
||||
"quant_config": quant_config,
|
||||
}
|
||||
|
||||
# Only add sparse head kwargs if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
|
||||
if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set():
|
||||
kwargs["sparse_head"] = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.value
|
||||
kwargs["model_path"] = model_config.model_path
|
||||
|
||||
return model_class(**kwargs)
|
||||
|
||||
|
||||
class BaseModelLoader(ABC):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import os
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -7,10 +8,12 @@ from torch import nn
|
||||
|
||||
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.sparse_pooler import SparsePooler
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.bert import BertEncoder
|
||||
from sglang.srt.utils.hf_transformers_utils import download_from_hf
|
||||
|
||||
RobertaConfig = None
|
||||
|
||||
@@ -205,12 +208,29 @@ class XLMRobertaModel(nn.Module):
|
||||
config: RobertaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
sparse_head: Optional[str] = None,
|
||||
model_path: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.roberta = XLMRobertaBaseModel(
|
||||
config=config, quant_config=quant_config, prefix=prefix
|
||||
)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
|
||||
if sparse_head is not None:
|
||||
self._is_sparse = True
|
||||
self._model_path = model_path
|
||||
self._sparse_head = sparse_head
|
||||
self.pooler = SparsePooler(config=config)
|
||||
# Zero out special tokens
|
||||
self._special_tokens = [
|
||||
config.bos_token_id,
|
||||
config.eos_token_id,
|
||||
config.pad_token_id,
|
||||
# self.config.unk_token_id # not available in the XLMRobertaConfig
|
||||
]
|
||||
self._special_tokens = [t for t in self._special_tokens if t is not None]
|
||||
else:
|
||||
self._is_sparse = False
|
||||
self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -223,11 +243,44 @@ class XLMRobertaModel(nn.Module):
|
||||
hidden_states = self.roberta(
|
||||
input_ids, positions, forward_batch, input_embeds, get_embedding
|
||||
)
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
embeddings = self.pooler(hidden_states, forward_batch)
|
||||
|
||||
if self._is_sparse:
|
||||
for token_id in self._special_tokens:
|
||||
embeddings.embeddings[:, token_id] = 0.0
|
||||
embeddings.embeddings = embeddings.embeddings.to_sparse()
|
||||
|
||||
return embeddings
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
self.roberta.load_weights(weights)
|
||||
|
||||
if self._is_sparse:
|
||||
sparse_dict = XLMRobertaModel._load_sparse_linear(
|
||||
self._model_path, self._sparse_head
|
||||
)
|
||||
self.pooler.load_weights(sparse_dict)
|
||||
|
||||
@staticmethod
|
||||
def _load_sparse_linear(model_path_or_dir: str, sparse_head: str) -> dict:
|
||||
"""
|
||||
Load sparse_head from local dir or HF Hub.
|
||||
Returns a state_dict suitable for nn.Linear.load_state_dict().
|
||||
"""
|
||||
if os.path.isdir(model_path_or_dir):
|
||||
path = os.path.join(model_path_or_dir, sparse_head)
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(
|
||||
f"'{sparse_head}' not found in {model_path_or_dir}"
|
||||
)
|
||||
else:
|
||||
# remote → use SGLang HF utility
|
||||
local_dir = download_from_hf(model_path_or_dir, allow_patterns=sparse_head)
|
||||
path = os.path.join(local_dir, sparse_head)
|
||||
|
||||
state_dict = torch.load(path)
|
||||
return state_dict
|
||||
|
||||
|
||||
class XLMRobertaForSequenceClassification(nn.Module):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user