From 0de7c2d09efe1e6bd25bbff5f572ca629c04e197 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 8 Aug 2024 00:04:15 -0700
Subject: [PATCH] Add e5-mistral modules - step 1/3 (#983)

---
 python/sglang/srt/layers/pooler.py          | 50 ++++++++++++
 python/sglang/srt/models/llama_embedding.py | 86 +++++++++++++++++++++
 2 files changed, 136 insertions(+)
 create mode 100644 python/sglang/srt/layers/pooler.py
 create mode 100644 python/sglang/srt/models/llama_embedding.py

diff --git a/python/sglang/srt/layers/pooler.py b/python/sglang/srt/layers/pooler.py
new file mode 100644
index 000000000..21752366a
--- /dev/null
+++ b/python/sglang/srt/layers/pooler.py
@@ -0,0 +1,50 @@
+# adapted from
+# https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
+
+from dataclasses import dataclass
+from enum import IntEnum
+
+import torch
+import torch.nn as nn
+
+from sglang.srt.model_executor.model_runner import InputMetadata
+
+
+class PoolingType(IntEnum):
+    LAST = 0
+
+
+@dataclass
+class EmbeddingPoolerOutput:
+    embeddings: torch.Tensor
+
+
+class Pooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `EmbeddingPoolerOutput`.
+    Attributes:
+        pooling_type: The type of pooling to use (currently only LAST).
+        normalize: Whether to normalize the pooled data.
+    """
+
+    def __init__(self, pooling_type: PoolingType, normalize: bool):
+        super().__init__()
+        self.pooling_type = pooling_type
+        self.normalize = normalize
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+    ) -> EmbeddingPoolerOutput:
+        if self.pooling_type == PoolingType.LAST:
+            last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+            pooled_data = hidden_states[last_token_indices]
+        else:
+            raise ValueError(f"Invalid pooling type: {self.pooling_type}")
+
+        if self.normalize:
+            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+
+        return EmbeddingPoolerOutput(embeddings=pooled_data)
diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py
new file mode 100644
index 000000000..b849a4b51
--- /dev/null
+++ b/python/sglang/srt/models/llama_embedding.py
@@ -0,0 +1,86 @@
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+
+
+class LlamaEmbeddingModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config=None,
+        cache_config=None,
+        efficient_weight_load=False,
+    ) -> None:
+        super().__init__()
+        self.model = LlamaModel(config, quant_config=quant_config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> EmbeddingPoolerOutput:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.pooler(hidden_states, input_metadata)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.model.named_parameters())
+
+        def load_weights_per_param(name, loaded_weight):
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                return
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                return
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    return
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    return
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        if name is None or loaded_weight is None:
+            for name, loaded_weight in weights:
+                load_weights_per_param(name, loaded_weight)
+        else:
+            load_weights_per_param(name, loaded_weight)
+
+
+EntryClass = LlamaEmbeddingModel
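
For reference, the following is a minimal standalone sketch (not part of the patch) of the last-token pooling that `Pooler` performs above. It assumes only what the patch shows: hidden states packed as a single `[total_tokens, hidden_size]` tensor and `extend_seq_lens` holding the per-request token counts; the concrete names and sizes below are illustrative.

    # Sketch only: last-token pooling over a packed batch, mirroring
    # Pooler(pooling_type=PoolingType.LAST, normalize=True).
    import torch
    import torch.nn.functional as F

    hidden_size = 8
    extend_seq_lens = torch.tensor([3, 5, 2])  # token counts of three packed requests (illustrative)
    hidden_states = torch.randn(int(extend_seq_lens.sum()), hidden_size)

    # cumsum - 1 gives the index of each request's final token in the packed layout.
    last_token_indices = torch.cumsum(extend_seq_lens, dim=0) - 1  # tensor([2, 7, 9])
    pooled = hidden_states[last_token_indices]                     # shape [3, hidden_size]

    # L2-normalize, matching normalize=True in LlamaEmbeddingModel's pooler.
    embeddings = F.normalize(pooled, p=2, dim=1)
    print(embeddings.shape)  # torch.Size([3, 8])

PoolingType.LAST is the only mode wired up here because e5-mistral-style embedding models take the final token's hidden state as the sequence embedding.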