[Feat] Return hidden states (experimental) (#3364)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Jackmin801
2025-02-10 15:54:37 -08:00
committed by GitHub
parent 2f47d710ae
commit 5f0e7de339
12 changed files with 204 additions and 5 deletions

View File

@@ -179,6 +179,59 @@
"asyncio.run(main())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Return Hidden States"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"\n",
"llm = sgl.Engine(\n",
" model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", return_hidden_states=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompts = [\n",
" \"Hello, my name is\",\n",
" \"The president of the United States is\",\n",
" \"The capital of France is\",\n",
" \"The future of AI is\",\n",
"]\n",
"\n",
"sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"max_new_tokens\": 10}\n",
"\n",
"outputs = llm.generate(prompts, sampling_params=sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print(\"===============================\")\n",
" print(\n",
" f\"Prompt: {prompt}\\nGenerated text: {output['text']}\\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\\tCompletion_tokens: {output['meta_info']['completion_tokens']}\\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}\"\n",
" )\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,