Migrate llama_classification to use the /classify interface (#2417)

Lianmin Zheng
2024-12-08 23:30:51 -08:00
parent 3844feb9bb
commit 835f8afc77
2 changed files with 30 additions and 25 deletions
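
With this change, the classification model is served through the embedding code path, so the server must be launched with --is-embedding (the assertion added in the diff below enforces this). A minimal client sketch follows; it assumes the default port 30000 and assumes /classify accepts a JSON body with a text field — the exact request schema is an assumption here, not taken from this commit:

    import requests

    # Launch the server first; the --is-embedding flag is required by the
    # assertion added in this commit:
    #   python -m sglang.launch_server --model-path <classification-model> --is-embedding

    # Hypothetical request shape for the /classify endpoint (schema assumed).
    resp = requests.post(
        "http://localhost:30000/classify",
        json={"text": "This movie was great!"},
    )
    resp.raise_for_status()
    print(resp.json())  # expected to carry the per-class scores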


@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -40,7 +40,7 @@ class LlamaForClassification(nn.Module):
         self.classification_head = nn.Linear(
             config.hidden_size, config.classification_out_size, bias=False
         )
-        self.eos_token_id = config.eos_token_id
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

     @torch.no_grad()
     def forward(
@@ -49,28 +49,17 @@ class LlamaForClassification(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."
+
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        is_eos_token = input_ids == self.eos_token_id
-        hidden_states = hidden_states[is_eos_token]
-        scores = self.classification_head(hidden_states)
-
-        if scores.shape[0] != forward_batch.batch_size:
-            print("Warning: the EOS tokens are missing in some sentences.")
-            scores = torch.ones(
-                (forward_batch.batch_size, self.config.classification_out_size)
-            ).to(input_ids.device)
-
-        logits_output = LogitsProcessorOutput(
-            next_token_logits=scores,
-            next_token_logprobs=scores,
-            normalized_prompt_logprobs=scores,
-            input_token_logprobs=torch.ones_like(input_ids),
-            input_top_logprobs=None,
-            output_top_logprobs=None,
-        )
-        return logits_output
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.classification_head(last_token_hidden)
+
+        return EmbeddingPoolerOutput(scores)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
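
The functional change in forward(): instead of gathering hidden states at EOS-token positions (which silently fell back to all-ones scores whenever an EOS was missing), the model now delegates to Pooler with PoolingType.LAST, which takes each sequence's final hidden state, and returns the head's logits wrapped in EmbeddingPoolerOutput. A standalone sketch of last-token pooling over a packed batch; the names and tensor layout here are illustrative, not sglang's actual Pooler internals:

    import torch

    def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
        # hidden_states: [total_tokens, hidden_size], sequences packed back to back.
        # seq_lens:      [batch_size], length of each sequence.
        last_idx = torch.cumsum(seq_lens, dim=0) - 1  # final-token index per sequence
        return hidden_states[last_idx]                # [batch_size, hidden_size]

    # Example: two sequences of lengths 3 and 2 packed into 5 token rows.
    h = torch.randn(5, 8)
    lens = torch.tensor([3, 2])
    pooled = last_token_pool(h, lens)  # selects rows 2 and 4 of h

Because the pooler always yields exactly one vector per sequence, the old batch-size mismatch warning and its all-ones fallback are no longer needed.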