fix black in pre-commit (#1940)

This commit is contained in:
Chayenne
2024-11-07 15:42:47 -08:00
committed by GitHub
parent dca87ec348
commit c77c1e05ba
29 changed files with 641 additions and 508 deletions

View File

@@ -71,7 +71,7 @@
"source": [
"import json\n",
"import os\n",
-"from typing import List\n",
+"from typing import List\n",
"\n",
"import chromadb\n",
"\n",
@@ -80,7 +80,7 @@
"if not os.path.exists(path_qca):\n",
" !wget https://virattt.github.io/datasets/abnb-2023-10k.json -O airbnb-2023-10k-qca.json\n",
"\n",
-"with open(path_qca, 'r') as f:\n",
+"with open(path_qca, \"r\") as f:\n",
" question_context_answers = json.load(f)\n",
"\n",
"chroma_client = chromadb.PersistentClient()\n",
@@ -88,7 +88,7 @@
"if collection.count() == 0:\n",
" collection.add(\n",
" documents=[qca[\"context\"] for qca in question_context_answers],\n",
-" ids=[str(i) for i in range(len(question_context_answers))]\n",
+" ids=[str(i) for i in range(len(question_context_answers))],\n",
" )"
],
"metadata": {
@@ -123,7 +123,7 @@
"\n",
"load_dotenv()\n",
"\n",
-"os.environ['TOKENIZERS_PARALLELISM'] = \"false\"\n",
+"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"p = Parea(api_key=os.getenv(\"PAREA_API_KEY\"), project_name=\"rag_sglang\")\n",
"p.integrate_with_sglang()\n",
@@ -150,10 +150,7 @@
"source": [
"@trace\n",
"def retrieval(question: str) -> List[str]:\n",
-" return collection.query(\n",
-" query_texts=[question],\n",
-" n_results=1\n",
-" )['documents'][0]"
+" return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
],
"metadata": {
"collapsed": false
@@ -176,7 +173,9 @@
"@function\n",
"def generation_sglang(s, question: str, *context: str):\n",
" context = \"\\n\".join(context)\n",
-" s += user(f'Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.')\n",
+" s += user(\n",
+" f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
+" )\n",
" s += assistant(gen(\"answer\"))\n",
"\n",
"\n",
@@ -223,7 +222,9 @@
" return generation(question, *contexts)\n",
"\n",
"\n",
-"rag_pipeline(\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\")"
+"rag_pipeline(\n",
+" \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
+")"
]
},
{
@@ -271,7 +272,10 @@
"execution_count": null,
"outputs": [],
"source": [
-"from parea.evals.rag import context_query_relevancy_factory, percent_target_supported_by_context_factory\n",
+"from parea.evals.rag import (\n",
+" context_query_relevancy_factory,\n",
+" percent_target_supported_by_context_factory,\n",
+")\n",
"\n",
"\n",
"context_relevancy_eval = context_query_relevancy_factory()\n",
@@ -280,10 +284,7 @@
"\n",
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
"def retrieval(question: str) -> List[str]:\n",
-" return collection.query(\n",
-" query_texts=[question],\n",
-" n_results=1\n",
-" )['documents'][0]"
+" return collection.query(query_texts=[question], n_results=1)[\"documents\"][0]"
],
"metadata": {
"collapsed": false
@@ -310,10 +311,13 @@
"answer_context_faithfulness = answer_context_faithfulness_statement_level_factory()\n",
"answer_matches_target_llm_grader = answer_matches_target_llm_grader_factory()\n",
"\n",
+"\n",
"@function\n",
"def generation_sglang(s, question: str, *context: str):\n",
" context = \"\\n\".join(context)\n",
-" s += user(f'Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.')\n",
+" s += user(\n",
+" f\"Given this question:\\n{question}\\n\\nAnd this context:\\n{context}\\n\\nAnswer the question.\"\n",
+" )\n",
" s += assistant(gen(\"answer\", max_tokens=1_000))\n",
"\n",
"\n",
@@ -357,7 +361,9 @@
" return generation(question, *contexts)\n",
"\n",
"\n",
-"rag_pipeline(\"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\")"
+"rag_pipeline(\n",
+" \"When did the World Health Organization formally declare an end to the COVID-19 global health emergency?\"\n",
+")"
],
"metadata": {
"collapsed": false
@@ -402,6 +408,7 @@
"source": [
"!pip install nest-asyncio\n",
"import nest_asyncio\n",
+"\n",
"nest_asyncio.apply()"
],
"metadata": {
@@ -461,7 +468,7 @@
],
"source": [
"e = p.experiment(\n",
-" 'RAG',\n",
+" \"RAG\",\n",
" data=[\n",
" {\n",
" \"question\": qca[\"question\"],\n",
@@ -469,7 +476,7 @@
" }\n",
" for qca in question_context_answers\n",
" ],\n",
-" func=rag_pipeline\n",
+" func=rag_pipeline,\n",
").run()"
],
"metadata": {

View File

@@ -7,6 +7,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
def main():
# Sample prompts.
prompts = [