fix black in pre-commit (#1940)
This commit is contained in:
@@ -71,7 +71,7 @@
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"from typing import List\n",
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"import chromadb\n",
|
||||
"\n",
|
||||
@@ -80,7 +80,7 @@
|
||||
"if not os.path.exists(path_qca):\n",
|
||||
" !wget https://virattt.github.io/datasets/abnb-2023-10k.json -O airbnb-2023-10k-qca.json\n",
|
||||
"\n",
|
||||
"with open(path_qca, 'r') as f:\n",
|
||||
"with open(path_qca, \"r\") as f:\n",
|
||||
" question_context_answers = json.load(f)\n",
|
||||
"\n",
|
||||
"chroma_client = chromadb.PersistentClient()\n",
|
||||
@@ -88,7 +88,7 @@
|
||||
"if collection.count() == 0:\n",
|
||||
" collection.add(\n",
|
||||
" documents=[qca[\"context\"] for qca in question_context_answers],\n",
|
||||
" ids=[str(i) for i in range(len(question_context_answers))]\n",
|
||||
" ids=[str(i) for i in range(len(question_context_answers))],\n",
|
||||
" )"
|
||||
],
|
||||
"metadata": {
|
||||
@@ -123,7 +123,7 @@
|
||||
"\n",
|
||||
"load_dotenv()\n",
|
||||
"\n",
|
||||
"os.environ['TOKENIZERS_PARALLELISM'] = \"false\"\n",
|
||||
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
||||
"\n",
|
||||
"p = Parea(api_key=os.getenv(\"PAREA_API_KEY\"), project_name=\"rag_sglang\")\n",
|
||||
"p.integrate_with_sglang()\n",
|
||||
@@ -150,10 +150,7 @@
|
||||
"source": [
|
||||
"@trace\n",
|
||||
"def retrieval(question: str) -> List[str]:\n",
|
||||
" return collection.query(\n",
|
||||
" query_texts=[question],\n",
|
||||
" n_results=1\n",
|
||||
" )['documents'][0]"
|
||||
" return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
@@ -176,7 +173,9 @@
|
||||
"@function\n",
|
||||
"def generation_sglang(s, question: str, *context: str):\n",
|
||||
" context = \"\\n\".join(context)\n",
|
||||
" s += user(f'Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.')\n",
|
||||
" s += user(\n",
|
||||
" f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
|
||||
" )\n",
|
||||
" s += assistant(gen(\"answer\"))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -223,7 +222,9 @@
|
||||
" return generation(question, *contexts)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"rag_pipeline(\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\")"
|
||||
"rag_pipeline(\n",
|
||||
" \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -271,7 +272,10 @@
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from parea.evals.rag import context_query_relevancy_factory, percent_target_supported_by_context_factory\n",
|
||||
"from parea.evals.rag import (\n",
|
||||
" context_query_relevancy_factory,\n",
|
||||
" percent_target_supported_by_context_factory,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"context_relevancy_eval = context_query_relevancy_factory()\n",
|
||||
@@ -280,10 +284,7 @@
|
||||
"\n",
|
||||
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
|
||||
"def retrieval(question: str) -> List[str]:\n",
|
||||
" return collection.query(\n",
|
||||
" query_texts=[question],\n",
|
||||
" n_results=1\n",
|
||||
" )['documents'][0]"
|
||||
" return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
@@ -310,10 +311,13 @@
|
||||
"answer_context_faithfulness = answer_context_faithfulness_statement_level_factory()\n",
|
||||
"answer_matches_target_llm_grader = answer_matches_target_llm_grader_factory()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@function\n",
|
||||
"def generation_sglang(s, question: str, *context: str):\n",
|
||||
" context = \"\\n\".join(context)\n",
|
||||
" s += user(f'Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.')\n",
|
||||
" s += user(\n",
|
||||
" f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
|
||||
" )\n",
|
||||
" s += assistant(gen(\"answer\", max_tokens=1_000))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -357,7 +361,9 @@
|
||||
" return generation(question, *contexts)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"rag_pipeline(\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\")"
|
||||
"rag_pipeline(\n",
|
||||
" \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
@@ -402,6 +408,7 @@
|
||||
"source": [
|
||||
"!pip install nest-asyncio\n",
|
||||
"import nest_asyncio\n",
|
||||
"\n",
|
||||
"nest_asyncio.apply()"
|
||||
],
|
||||
"metadata": {
|
||||
@@ -461,7 +468,7 @@
|
||||
],
|
||||
"source": [
|
||||
"e = p.experiment(\n",
|
||||
" 'RAG',\n",
|
||||
" \"RAG\",\n",
|
||||
" data=[\n",
|
||||
" {\n",
|
||||
" \"question\": qca[\"question\"],\n",
|
||||
@@ -469,7 +476,7 @@
|
||||
" }\n",
|
||||
" for qca in question_context_answers\n",
|
||||
" ],\n",
|
||||
" func=rag_pipeline\n",
|
||||
" func=rag_pipeline,\n",
|
||||
").run()"
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@@ -7,6 +7,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
def main():
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
|
||||
Reference in New Issue
Block a user