Migrate llama_classification to use the /classify interface (#2417)

This commit is contained in:
Lianmin Zheng
2024-12-08 23:30:51 -08:00
parent 3844feb9bb
commit 835f8afc77
2 changed files with 30 additions and 25 deletions

View File

@@ -1,6 +1,6 @@
"""
Usage:
python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache
python3 test_httpserver_classify.py
"""
@@ -11,7 +11,7 @@ import numpy as np
import requests
def get_logits(url, prompt):
def get_logits_deprecated(url: str, prompt: str):
response = requests.post(
url + "/generate",
json={
@@ -25,7 +25,7 @@ def get_logits(url, prompt):
return response.json()["meta_info"]["normalized_prompt_logprob"]
def get_logits_batch(url, prompts):
def get_logits_batch_deprecated(url: str, prompts: list[str]):
response = requests.post(
url + "/generate",
json={
@@ -46,6 +46,22 @@ def get_logits_batch(url, prompts):
return logits
def get_logits(url: str, prompt: str):
response = requests.post(
url + "/classify",
json={"text": prompt},
)
return response.json()["embedding"]
def get_logits_batch(url: str, prompts: list[str]):
response = requests.post(
url + "/classify",
json={"text": prompts},
)
return np.array([x["embedding"] for x in response.json()])
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")