Implement BGE-M3 Sparse Embeddings in SGLang (#10869)
Co-authored-by: Christian Bahls <christian.bahls@planet-ai.de> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -237,6 +237,9 @@ class Envs:
|
|||||||
SGLANG_KT_AMX_METHOD = EnvStr(None)
|
SGLANG_KT_AMX_METHOD = EnvStr(None)
|
||||||
SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None)
|
SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None)
|
||||||
|
|
||||||
|
# Sparse Embeddings
|
||||||
|
SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
98
python/sglang/srt/layers/sparse_pooler.py
Normal file
98
python/sglang/srt/layers/sparse_pooler.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SparseEmbeddingOutput:
    """Result container returned by SparsePooler.forward()."""

    # One row per sequence in the batch; each column holds the pooled
    # (max-reduced) ReLU weight for the corresponding vocabulary id.
    embeddings: torch.Tensor  # [batch_size, vocab_size]
|
||||||
|
|
||||||
|
|
||||||
|
class SparsePooler(nn.Module):
    """Pool packed hidden states into vocabulary-space sparse embeddings.

    Pipeline:
        1. A linear head followed by ReLU turns each token's hidden state
           into a single non-negative weight.
        2. Each weight is scattered to its token id's slot in a
           [batch_size, vocab_size] matrix.
        3. Repeated tokens inside one sequence are merged with max pooling.

    Attributes:
        vocab_size: Width of the output embedding (one slot per vocab id).
        sparse_linear: nn.Linear(hidden_size, 1) producing token weights.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        # Fail fast if the HF config lacks the fields this pooler relies on.
        for required in ("vocab_size", "hidden_size"):
            if not hasattr(config, required):
                raise AttributeError(
                    f"Config {type(config)} missing required '{required}' attribute"
                )

        self.vocab_size = config.vocab_size
        self.sparse_linear = nn.Linear(config.hidden_size, 1)
        # Flipped to True by load_weights(); forward() refuses to run before that.
        self._weights_loaded = False

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> SparseEmbeddingOutput:
        """Compute sparse embeddings for a packed batch.

        Args:
            hidden_states: Packed hidden states, shape [total_tokens, hidden_size].
            forward_batch: Carries extend_seq_lens and input_ids for the batch.

        Returns:
            SparseEmbeddingOutput whose embeddings have shape
            [batch_size, vocab_size].

        Raises:
            ValueError: If load_weights() has not been called yet.
        """
        if not self._weights_loaded:
            raise ValueError(
                "Sparse pooling weights not loaded. Call load_weights() first"
            )

        # Linear head + ReLU -> one non-negative weight per token: [total_tokens]
        weights = F.relu(self.sparse_linear(hidden_states)).squeeze(-1)

        seq_lens = forward_batch.extend_seq_lens
        num_seqs = len(seq_lens)

        # For every packed token, the index of the sequence it belongs to.
        seq_ids = torch.repeat_interleave(
            torch.arange(num_seqs, device=hidden_states.device),
            seq_lens,
        )

        out = torch.zeros(
            num_seqs,
            self.vocab_size,
            dtype=weights.dtype,
            device=weights.device,
        )

        # Flatten (sequence, token_id) into one index and max-reduce so that
        # duplicate tokens within a sequence keep their largest weight.
        scatter_idx = seq_ids * self.vocab_size + forward_batch.input_ids
        out.view(-1).scatter_reduce_(0, scatter_idx, weights, reduce="amax")

        return SparseEmbeddingOutput(embeddings=out)

    def load_weights(self, state_dict: dict):
        """Load the sparse head weights and mark the pooler as usable."""
        self.sparse_linear.load_state_dict(state_dict)
        self._weights_loaded = True
|
||||||
@@ -961,7 +961,7 @@ class BatchEmbeddingOutput(BaseBatchReq):
|
|||||||
# The finish reason
|
# The finish reason
|
||||||
finished_reasons: List[BaseFinishReason]
|
finished_reasons: List[BaseFinishReason]
|
||||||
# The output embedding
|
# The output embedding
|
||||||
embeddings: List[List[float]]
|
embeddings: Union[List[List[float]], List[Dict[int, float]]]
|
||||||
# Token counts
|
# Token counts
|
||||||
prompt_tokens: List[int]
|
prompt_tokens: List[int]
|
||||||
cached_tokens: List[int]
|
cached_tokens: List[int]
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
from sglang.srt.environ import envs
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
@@ -175,7 +176,21 @@ class SchedulerOutputProcessorMixin:
|
|||||||
logprob_pt += num_input_logprobs
|
logprob_pt += num_input_logprobs
|
||||||
|
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
embeddings = result.embeddings.tolist()
|
is_sparse = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
|
||||||
|
|
||||||
|
embeddings = result.embeddings
|
||||||
|
|
||||||
|
if is_sparse:
|
||||||
|
batch_ids, token_ids = embeddings.indices()
|
||||||
|
values = embeddings.values()
|
||||||
|
|
||||||
|
embeddings = [{} for _ in range(embeddings.size(0))]
|
||||||
|
for i in range(batch_ids.shape[0]):
|
||||||
|
embeddings[batch_ids[i].item()][token_ids[i].item()] = values[
|
||||||
|
i
|
||||||
|
].item()
|
||||||
|
else:
|
||||||
|
embeddings = embeddings.tolist()
|
||||||
|
|
||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ from sglang.srt.model_loader.utils import (
|
|||||||
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
|
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
|
||||||
0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration
|
0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration
|
||||||
)
|
)
|
||||||
|
from sglang.srt.environ import envs
|
||||||
from sglang.srt.model_loader.weight_utils import (
|
from sglang.srt.model_loader.weight_utils import (
|
||||||
download_safetensors_index_file_from_hf,
|
download_safetensors_index_file_from_hf,
|
||||||
download_weights_from_hf,
|
download_weights_from_hf,
|
||||||
@@ -244,10 +245,19 @@ def _initialize_model(
|
|||||||
quant_config = _get_quantization_config(
|
quant_config = _get_quantization_config(
|
||||||
model_config, load_config, packed_modules_mapping
|
model_config, load_config, packed_modules_mapping
|
||||||
)
|
)
|
||||||
return model_class(
|
|
||||||
config=model_config.hf_config,
|
# Build kwargs conditionally
|
||||||
quant_config=quant_config,
|
kwargs = {
|
||||||
)
|
"config": model_config.hf_config,
|
||||||
|
"quant_config": quant_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Only add sparse head kwargs if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
|
||||||
|
if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set():
|
||||||
|
kwargs["sparse_head"] = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.value
|
||||||
|
kwargs["model_path"] = model_config.model_path
|
||||||
|
|
||||||
|
return model_class(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
class BaseModelLoader(ABC):
|
class BaseModelLoader(ABC):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import Iterable, Optional, Tuple
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -7,10 +8,12 @@ from torch import nn
|
|||||||
|
|
||||||
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
|
from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.sparse_pooler import SparsePooler
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.bert import BertEncoder
|
from sglang.srt.models.bert import BertEncoder
|
||||||
|
from sglang.srt.utils.hf_transformers_utils import download_from_hf
|
||||||
|
|
||||||
RobertaConfig = None
|
RobertaConfig = None
|
||||||
|
|
||||||
@@ -205,12 +208,29 @@ class XLMRobertaModel(nn.Module):
|
|||||||
config: RobertaConfig,
|
config: RobertaConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
sparse_head: Optional[str] = None,
|
||||||
|
model_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.roberta = XLMRobertaBaseModel(
|
self.roberta = XLMRobertaBaseModel(
|
||||||
config=config, quant_config=quant_config, prefix=prefix
|
config=config, quant_config=quant_config, prefix=prefix
|
||||||
)
|
)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
|
if sparse_head is not None:
|
||||||
|
self._is_sparse = True
|
||||||
|
self._model_path = model_path
|
||||||
|
self._sparse_head = sparse_head
|
||||||
|
self.pooler = SparsePooler(config=config)
|
||||||
|
# Zero out special tokens
|
||||||
|
self._special_tokens = [
|
||||||
|
config.bos_token_id,
|
||||||
|
config.eos_token_id,
|
||||||
|
config.pad_token_id,
|
||||||
|
# self.config.unk_token_id # not available in the XLMRobertaConfig
|
||||||
|
]
|
||||||
|
self._special_tokens = [t for t in self._special_tokens if t is not None]
|
||||||
|
else:
|
||||||
|
self._is_sparse = False
|
||||||
|
self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -223,11 +243,44 @@ class XLMRobertaModel(nn.Module):
|
|||||||
hidden_states = self.roberta(
|
hidden_states = self.roberta(
|
||||||
input_ids, positions, forward_batch, input_embeds, get_embedding
|
input_ids, positions, forward_batch, input_embeds, get_embedding
|
||||||
)
|
)
|
||||||
return self.pooler(hidden_states, forward_batch)
|
embeddings = self.pooler(hidden_states, forward_batch)
|
||||||
|
|
||||||
|
if self._is_sparse:
|
||||||
|
for token_id in self._special_tokens:
|
||||||
|
embeddings.embeddings[:, token_id] = 0.0
|
||||||
|
embeddings.embeddings = embeddings.embeddings.to_sparse()
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load backbone weights, plus the sparse head when sparse pooling is on."""
    # Backbone (RoBERTa encoder) weights come from the regular checkpoint.
    self.roberta.load_weights(weights)

    if not self._is_sparse:
        return

    # The sparse head lives in a separate file; fetch it and hand the
    # resulting state dict to the pooler.
    head_state = self._load_sparse_linear(self._model_path, self._sparse_head)
    self.pooler.load_weights(head_state)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_sparse_linear(model_path_or_dir: str, sparse_head: str) -> dict:
|
||||||
|
"""
|
||||||
|
Load sparse_head from local dir or HF Hub.
|
||||||
|
Returns a state_dict suitable for nn.Linear.load_state_dict().
|
||||||
|
"""
|
||||||
|
if os.path.isdir(model_path_or_dir):
|
||||||
|
path = os.path.join(model_path_or_dir, sparse_head)
|
||||||
|
if not os.path.exists(path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"'{sparse_head}' not found in {model_path_or_dir}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# remote → use SGLang HF utility
|
||||||
|
local_dir = download_from_hf(model_path_or_dir, allow_patterns=sparse_head)
|
||||||
|
path = os.path.join(local_dir, sparse_head)
|
||||||
|
|
||||||
|
state_dict = torch.load(path)
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
class XLMRobertaForSequenceClassification(nn.Module):
|
class XLMRobertaForSequenceClassification(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user