Add Reward API Docs etc (#1910)

Co-authored-by: Chayenne <zhaochenyang@g.ucla.edu>
This commit is contained in:
Chayenne
2024-11-03 22:33:03 -08:00
committed by GitHub
parent 1853c3523b
commit 704f8e8ed1
6 changed files with 90 additions and 16 deletions


@@ -5,9 +5,10 @@
"metadata": {},
"source": [
"# Native APIs\n",
"\n",
"Apart from the OpenAI-compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce the following APIs:\n",
"\n",
"- `/generate`\n",
"- `/generate` (text generation model)\n",
"- `/get_server_args`\n",
"- `/get_model_info`\n",
"- `/health`\n",
@@ -15,7 +16,8 @@
"- `/flush_cache`\n",
"- `/get_memory_pool_size`\n",
"- `/update_weights`\n",
"- `/encode`\n",
"- `/encode` (embedding model)\n",
"- `/judge` (reward model)\n",
"\n",
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`."
]
@@ -40,6 +42,8 @@
" print_highlight,\n",
")\n",
"\n",
"import requests\n",
"\n",
"server_process = execute_shell_command(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010\n",
@@ -53,7 +57,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate\n",
"## Generate (text generation model)\n",
"Generate completions. This is similar to the `/v1/completions` in OpenAI API. Detailed parameters can be found in the [sampling parameters](../references/sampling_params.html)."
]
},
@@ -63,8 +67,6 @@
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"\n",
"url = \"http://localhost:30010/generate\"\n",
"data = {\"text\": \"What is the capital of France?\"}\n",
"\n",
@@ -252,7 +254,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Encode\n",
"## Encode (embedding model)\n",
"\n",
"Encode text into embeddings. Note that this API is only available for [embedding models](openai_api_embeddings.html#openai-apis-embedding) and will raise an error for generation models.\n",
"Therefore, we launch a new server to serve an embedding model."
@@ -292,13 +294,76 @@
"print_highlight(f\"Text embedding (first 10): {response_json['embedding'][:10]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Judge (reward model)\n",
"\n",
"SGLang Runtime also supports reward models. Here we use a reward model to score a pair of candidate responses to the same prompt."
]
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(embedding_process)"
"terminate_process(embedding_process)\n",
"\n",
"# Note that SGLang currently treats embedding models and reward models as the same model type,\n",
"# which is why the reward model is launched with `--is-embedding`. This may change in the future.\n",
"\n",
"reward_process = execute_shell_command(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --port 30030 --host 0.0.0.0 --is-embedding\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(\"http://localhost:30030\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"PROMPT = (\n",
" \"What is the range of the numeric output of a sigmoid node in a neural network?\"\n",
")\n",
"\n",
"RESPONSE1 = \"The output of a sigmoid node is bounded between -1 and 1.\"\n",
"RESPONSE2 = \"The output of a sigmoid node is bounded between 0 and 1.\"\n",
"\n",
"CONVS = [\n",
" [{\"role\": \"user\", \"content\": PROMPT}, {\"role\": \"assistant\", \"content\": RESPONSE1}],\n",
" [{\"role\": \"user\", \"content\": PROMPT}, {\"role\": \"assistant\", \"content\": RESPONSE2}],\n",
"]\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\")\n",
"prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)\n",
"\n",
"url = \"http://localhost:30030/judge\"\n",
"data = {\n",
"    \"model\": \"Skywork/Skywork-Reward-Llama-3.1-8B-v0.2\",\n",
" \"text\": prompts\n",
"}\n",
"\n",
"responses = requests.post(url, json=data).json()\n",
"for response in responses:\n",
" print_highlight(f\"reward: {response['embedding'][0]}\")"
]
},
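Since the rewards come back in the `embedding` field of each response, a small follow-up sketch shows how the two scores could be compared to pick the preferred response. The score values below are hypothetical placeholders, not actual model output; in the notebook they would be `response["embedding"][0]` for each conversation.

```python
# Sketch: choose the preferred response given the two reward scores
# returned by /judge. The scores below are hypothetical placeholders;
# a higher reward indicates a better response.
scores = [-20.0, 1.0]  # hypothetical rewards for RESPONSE1 and RESPONSE2

# Index of the highest-scoring conversation.
best_index = max(range(len(scores)), key=lambda i: scores[i])
print(f"Preferred response: RESPONSE{best_index + 1}")  # -> RESPONSE2
```

Here RESPONSE2 (sigmoid output bounded between 0 and 1) is the factually correct answer, so a well-trained reward model is expected to assign it the higher score.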
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(reward_process)"
]
}
],