Files
sglang/python/sglang/srt/models/llama_classification.py

78 lines
2.9 KiB
Python
Raw Normal View History

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
2024-07-28 23:07:12 +10:00
2024-06-22 00:45:33 -07:00
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
2024-09-19 20:53:11 +08:00
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
2024-06-22 00:45:33 -07:00
class LlamaForClassification(nn.Module):
    """Llama backbone followed by a bias-free linear classification head.

    The forward pass runs the ``LlamaModel`` backbone, pools its hidden
    states with ``PoolingType.LAST`` (normalization disabled), and projects
    the pooled vectors to ``config.classification_out_size`` scores. The
    scores are returned wrapped in an ``EmbeddingPoolerOutput``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the backbone, the classification head, and the pooler.

        Args:
            config: Model configuration. ``hidden_size`` and
                ``classification_out_size`` are read from it to size the
                classification head.
            quant_config: Optional quantization configuration. It is stored
                on the instance and forwarded to the ``LlamaModel`` backbone.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = LlamaModel(config, quant_config=quant_config)

        # The head is a plain ``nn.Linear``; ``quant_config`` is not passed
        # to it.
        self.classification_head = nn.Linear(
            config.hidden_size, config.classification_out_size, bias=False
        )
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
        """Compute classification scores for a batch.

        Runs under ``torch.no_grad()``, so no autograd graph is recorded.

        Args:
            input_ids: Forwarded to the backbone as its first argument.
            positions: Forwarded to the backbone as its second argument.
            forward_batch: Batch metadata; forwarded to both the backbone
                and the pooler.
            input_embeds: Optional tensor forwarded to the backbone as its
                fourth argument; ``None`` by default.
            get_embedding: Must be truthy. This model only supports the
                embedding-style output path.

        Returns:
            An ``EmbeddingPoolerOutput`` wrapping the classification-head
            output computed from the pooled hidden states.

        Raises:
            AssertionError: If ``get_embedding`` is falsy.
        """
        assert (
            get_embedding
        ), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."

        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        # Pool first, then classify: the head only sees the pooled vectors.
        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
        scores = self.classification_head(last_token_hidden)

        return EmbeddingPoolerOutput(scores)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        """Load checkpoint tensors into this module.

        Each ``(name, tensor)`` pair is routed by substring match on the
        name:

        * names containing ``"classification_head"`` are loaded directly
          into the matching parameter of this module;
        * names containing ``"lm_head"`` are skipped;
        * every other name is delegated, one pair at a time, to
          ``LlamaForCausalLM.load_weights`` with this instance as ``self``.

        Args:
            weights: Iterable of ``(parameter_name, tensor)`` pairs.

        Raises:
            KeyError: If a name containing ``"classification_head"`` is not
                one of this module's registered parameters.
        """
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "classification_head" in name:
                param = params_dict[name]
                # Honour a parameter-specific loader when one is attached;
                # otherwise fall back to the default loader.
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            elif "lm_head" in name:
                # This class registers no ``lm_head`` attribute, so such
                # checkpoint entries have nowhere to go.
                continue
            else:
                # Reuse the causal-LM loading logic for the backbone weights.
                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
2024-06-22 00:45:33 -07:00
2024-07-05 10:06:17 -07:00
# Module-level alias for the model class defined in this file.
# NOTE(review): presumably the name the model loader looks up to discover
# this architecture — confirm against the loader/registry code.
EntryClass = LlamaForClassification