Migrate llama_classification to use the /classify interface (#2417)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification
|
||||
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache
|
||||
|
||||
python3 test_httpserver_classify.py
|
||||
"""
|
||||
@@ -11,7 +11,7 @@ import numpy as np
|
||||
import requests
|
||||
|
||||
|
||||
def get_logits(url, prompt):
|
||||
def get_logits_deprecated(url: str, prompt: str):
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
@@ -25,7 +25,7 @@ def get_logits(url, prompt):
|
||||
return response.json()["meta_info"]["normalized_prompt_logprob"]
|
||||
|
||||
|
||||
def get_logits_batch(url, prompts):
|
||||
def get_logits_batch_deprecated(url: str, prompts: list[str]):
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
@@ -46,6 +46,22 @@ def get_logits_batch(url, prompts):
|
||||
return logits
|
||||
|
||||
|
||||
def get_logits(url: str, prompt: str):
    """POST a single prompt to the server's /classify endpoint.

    Args:
        url: Base URL of the running sglang server (e.g. "http://127.0.0.1:30000").
        prompt: The text to classify.

    Returns:
        The "embedding" field of the JSON response — the classification
        logits for this prompt.
    """
    payload = {"text": prompt}
    resp = requests.post(f"{url}/classify", json=payload)
    body = resp.json()
    return body["embedding"]
|
||||
|
||||
|
||||
def get_logits_batch(url: str, prompts: list[str]):
    """POST a batch of prompts to the server's /classify endpoint.

    Args:
        url: Base URL of the running sglang server.
        prompts: Texts to classify in one request.

    Returns:
        A 2-D numpy array whose rows are the per-prompt classification
        logits (the "embedding" field of each item in the JSON response).
    """
    payload = {"text": prompts}
    resp = requests.post(f"{url}/classify", json=payload)
    items = resp.json()
    embeddings = [item["embedding"] for item in items]
    return np.array(embeddings)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
|
||||
Reference in New Issue
Block a user