Migrate llama_classification to use the /classify interface (#2417)
This commit is contained in:
@@ -18,7 +18,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
@@ -40,7 +40,7 @@ class LlamaForClassification(nn.Module):
|
|||||||
self.classification_head = nn.Linear(
|
self.classification_head = nn.Linear(
|
||||||
config.hidden_size, config.classification_out_size, bias=False
|
config.hidden_size, config.classification_out_size, bias=False
|
||||||
)
|
)
|
||||||
self.eos_token_id = config.eos_token_id
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -49,28 +49,17 @@ class LlamaForClassification(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
get_embedding: bool = True,
|
||||||
|
) -> EmbeddingPoolerOutput:
|
||||||
|
assert (
|
||||||
|
get_embedding
|
||||||
|
), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."
|
||||||
|
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||||
is_eos_token = input_ids == self.eos_token_id
|
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
|
||||||
hidden_states = hidden_states[is_eos_token]
|
scores = self.classification_head(last_token_hidden)
|
||||||
scores = self.classification_head(hidden_states)
|
|
||||||
|
|
||||||
if scores.shape[0] != forward_batch.batch_size:
|
return EmbeddingPoolerOutput(scores)
|
||||||
print("Warning: the EOS tokens are missing in some sentences.")
|
|
||||||
scores = torch.ones(
|
|
||||||
(forward_batch.batch_size, self.config.classification_out_size)
|
|
||||||
).to(input_ids.device)
|
|
||||||
|
|
||||||
logits_output = LogitsProcessorOutput(
|
|
||||||
next_token_logits=scores,
|
|
||||||
next_token_logprobs=scores,
|
|
||||||
normalized_prompt_logprobs=scores,
|
|
||||||
input_token_logprobs=torch.ones_like(input_ids),
|
|
||||||
input_top_logprobs=None,
|
|
||||||
output_top_logprobs=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification
|
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache
|
||||||
|
|
||||||
python3 test_httpserver_classify.py
|
python3 test_httpserver_classify.py
|
||||||
"""
|
"""
|
||||||
@@ -11,7 +11,7 @@ import numpy as np
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def get_logits(url, prompt):
|
def get_logits_deprecated(url: str, prompt: str):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
@@ -25,7 +25,7 @@ def get_logits(url, prompt):
|
|||||||
return response.json()["meta_info"]["normalized_prompt_logprob"]
|
return response.json()["meta_info"]["normalized_prompt_logprob"]
|
||||||
|
|
||||||
|
|
||||||
def get_logits_batch(url, prompts):
|
def get_logits_batch_deprecated(url: str, prompts: list[str]):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
@@ -46,6 +46,22 @@ def get_logits_batch(url, prompts):
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def get_logits(url: str, prompt: str):
|
||||||
|
response = requests.post(
|
||||||
|
url + "/classify",
|
||||||
|
json={"text": prompt},
|
||||||
|
)
|
||||||
|
return response.json()["embedding"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_logits_batch(url: str, prompts: list[str]):
|
||||||
|
response = requests.post(
|
||||||
|
url + "/classify",
|
||||||
|
json={"text": prompts},
|
||||||
|
)
|
||||||
|
return np.array([x["embedding"] for x in response.json()])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||||
|
|||||||
Reference in New Issue
Block a user