Initialize project; model provided by the ModelHub XC community
Model: Salesforce/xLAM-1b-fc-r Source: Original Platform
565
examples/demo.ipynb
Normal file
@@ -0,0 +1,565 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# xLAM Model Function-Calling Capabilities Demo\n",
|
||||
"\n",
"This notebook demonstrates the function-calling capabilities of the xLAM model, which turns a natural-language query and a list of available tools into structured function calls.\n",
|
||||
"\n",
|
||||
"We will cover the following steps:\n",
|
||||
"1. Setup and Initialization\n",
"2. Example Usage with Demo APIs\n",
|
||||
"3. Executing Real-Time Weather API Calls\n",
|
||||
"\n",
|
||||
"Let's get started!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Setup and Initialization\n",
|
||||
"\n",
"First, set up the environment and initialize the `XLAMHandler` class. Make sure the following dependencies are installed:\n",
|
||||
"- `vllm`\n",
|
||||
"- `jinja2`\n",
|
||||
"- `requests`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
"Next, we'll import the necessary modules and define the `XLAMHandler` class and its helper methods. The full script is in the cell below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/export/home/conda/envs/rl/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"2024-07-18 07:25:11,294\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO 07-18 07:25:13 llm_engine.py:161] Initializing an LLM engine (v0.5.0) with config: model='Salesforce/xLAM-1b-fc-r', speculative_config=None, tokenizer='Salesforce/xLAM-1b-fc-r', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=65536, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=Salesforce/xLAM-1b-fc-r)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO 07-18 07:25:24 weight_utils.py:218] Using model weights format ['*.safetensors']\n",
|
||||
"INFO 07-18 07:25:24 weight_utils.py:261] No model.safetensors.index.json found in remote.\n",
|
||||
"INFO 07-18 07:25:25 model_runner.py:159] Loading model weights took 2.5583 GB\n",
|
||||
"INFO 07-18 07:25:31 gpu_executor.py:83] # GPU blocks: 10075, # CPU blocks: 1365\n",
|
||||
"INFO 07-18 07:25:40 model_runner.py:878] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.\n",
|
||||
"INFO 07-18 07:25:40 model_runner.py:882] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.\n",
|
||||
"INFO 07-18 07:26:02 model_runner.py:954] Graph capturing finished in 22 secs.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"from typing import List, Dict\n",
|
||||
"\n",
|
||||
"from vllm import LLM, SamplingParams\n",
|
||||
"from jinja2 import Template\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"TASK_INSTRUCTION = \"\"\"\n",
|
||||
"You are an expert in composing functions. You are given a question and a set of possible functions. \n",
|
||||
"Based on the question, you will need to make one or more function/tool calls to achieve the purpose. \n",
|
||||
"If none of the functions can be used, point it out and refuse to answer. \n",
|
||||
"If the given question lacks the parameters required by the function, also point it out.\n",
|
||||
"\"\"\".strip()\n",
|
||||
"\n",
|
||||
"FORMAT_INSTRUCTION = \"\"\"\n",
|
||||
"The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n",
|
||||
"The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"tool_calls\": [\n",
|
||||
" {\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},\n",
|
||||
" ... (more tool calls as required)\n",
|
||||
" ]\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\"\"\".strip()\n",
|
||||
"\n",
|
||||
"class XLAMHandler:\n",
|
||||
" def __init__(self, \n",
|
||||
" model: str, \n",
|
||||
" temperature: float = 0.3, \n",
|
||||
" top_p: float = 1, \n",
|
||||
" max_tokens: int = 512,\n",
|
||||
" tensor_parallel_size: int = 1,\n",
|
||||
" dtype: str = \"bfloat16\"):\n",
|
||||
" \n",
|
||||
" # Initialize LLM with GPU specifications\n",
|
||||
" self.llm = LLM(model=model,\n",
|
||||
" tensor_parallel_size=tensor_parallel_size,\n",
|
||||
" dtype=dtype)\n",
|
||||
" \n",
|
||||
" self.sampling_params = SamplingParams(\n",
|
||||
" temperature=temperature,\n",
|
||||
" top_p=top_p,\n",
|
||||
" max_tokens=max_tokens\n",
|
||||
" )\n",
|
||||
" self.chat_template = self.llm.get_tokenizer().chat_template\n",
|
||||
" \n",
|
||||
" @staticmethod\n",
|
||||
" def apply_chat_template(template, messages):\n",
|
||||
" jinja_template = Template(template)\n",
|
||||
" return jinja_template.render(messages=messages)\n",
|
||||
"\n",
|
||||
" def process_query(self, query: str, tools: list, task_instruction: str, format_instruction: str):\n",
|
||||
" # Convert tools to XLAM format\n",
|
||||
" xlam_tools = self.convert_to_xlam_tool(tools)\n",
|
||||
"\n",
|
||||
" # Build the input prompt\n",
|
||||
" prompt = self.build_prompt(query, xlam_tools, task_instruction, format_instruction)\n",
|
||||
"\n",
|
||||
" messages = [\n",
|
||||
" {\"role\": \"user\", \"content\": prompt}\n",
|
||||
" ]\n",
|
||||
" formatted_prompt = self.apply_chat_template(self.chat_template, messages)\n",
|
||||
"\n",
|
||||
" # Make inference\n",
|
||||
" start_time = time.time()\n",
|
||||
" outputs = self.llm.generate([formatted_prompt], self.sampling_params)\n",
|
||||
" latency = time.time() - start_time\n",
|
||||
"\n",
" # Approximate throughput: this counts whitespace-separated words in the output text, not model tokens\n",
|
||||
" tokens_generated = sum(len(output.text.split()) for output in outputs[0].outputs)\n",
|
||||
" tokens_per_second = tokens_generated / latency\n",
|
||||
"\n",
|
||||
" # Parse response\n",
|
||||
" result = outputs[0].outputs[0].text\n",
|
||||
" parsed_result, success, _ = self.parse_response(result)\n",
|
||||
"\n",
|
||||
" # Prepare metadata\n",
|
||||
" metadata = {\n",
|
||||
" \"latency\": latency,\n",
|
||||
" \"tokens_per_second\": tokens_per_second,\n",
|
||||
" \"success\": success,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" return parsed_result, metadata\n",
|
||||
"\n",
|
||||
" def convert_to_xlam_tool(self, tools):\n",
|
||||
" if isinstance(tools, dict):\n",
|
||||
" return {\n",
|
||||
" \"name\": tools[\"name\"],\n",
|
||||
" \"description\": tools[\"description\"],\n",
|
||||
" \"parameters\": {k: v for k, v in tools[\"parameters\"].get(\"properties\", {}).items()}\n",
|
||||
" }\n",
|
||||
" elif isinstance(tools, list):\n",
|
||||
" return [self.convert_to_xlam_tool(tool) for tool in tools]\n",
|
||||
" else:\n",
|
||||
" return tools\n",
|
||||
"\n",
|
||||
" def build_prompt(self, query, tools, task_instruction=TASK_INSTRUCTION, format_instruction=FORMAT_INSTRUCTION):\n",
|
||||
" prompt = f\"[BEGIN OF TASK INSTRUCTION]\\n{task_instruction}\\n[END OF TASK INSTRUCTION]\\n\\n\"\n",
|
||||
" prompt += f\"[BEGIN OF AVAILABLE TOOLS]\\n{json.dumps(tools)}\\n[END OF AVAILABLE TOOLS]\\n\\n\"\n",
|
||||
" prompt += f\"[BEGIN OF FORMAT INSTRUCTION]\\n{format_instruction}\\n[END OF FORMAT INSTRUCTION]\\n\\n\"\n",
|
||||
" prompt += f\"[BEGIN OF QUERY]\\n{query}\\n[END OF QUERY]\\n\\n\"\n",
|
||||
" return prompt\n",
|
||||
"\n",
|
||||
" def parse_response(self, response):\n",
|
||||
" try:\n",
|
||||
" data = json.loads(response)\n",
|
||||
" tool_calls = data.get('tool_calls', []) if isinstance(data, dict) else data\n",
|
||||
" result = [\n",
|
||||
" {tool_call['name']: tool_call['arguments']}\n",
|
||||
" for tool_call in tool_calls if isinstance(tool_call, dict)\n",
|
||||
" ]\n",
|
||||
" return result, True, []\n",
" except (json.JSONDecodeError, KeyError, TypeError):\n",
|
||||
" return [], False, [\"Failed to parse JSON response\"]\n",
|
||||
"\n",
|
||||
"handler = XLAMHandler(model=\"Salesforce/xLAM-1b-fc-r\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Example Usage with Demo APIs\n",
|
||||
"\n",
"In this section, we'll use the `XLAMHandler` class with a few example APIs. We start by defining several API tools and some test queries."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Query: What's the weather like in New York in Fahrenheit?\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 4.51it/s, Generation Speed: 176.89 toks/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Result: [\n",
|
||||
" {\n",
|
||||
" \"get_weather\": {\n",
|
||||
" \"location\": \"New York\",\n",
|
||||
" \"unit\": \"fahrenheit\"\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"Latency: 0.22673869132995605\n",
|
||||
"Speed: 39.69326958363258\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Query: What is the stock price of CRM?\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 5.86it/s, Generation Speed: 182.37 toks/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Result: [\n",
|
||||
" {\n",
|
||||
" \"get_stock_price\": {\n",
|
||||
" \"symbol\": \"CRM\"\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"Latency: 0.17523670196533203\n",
|
||||
"Speed: 34.23940266341585\n",
|
||||
"--------------------------------------------------\n",
|
||||
"Query: Tell me the temperature in London in Celsius\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 5.08it/s, Generation Speed: 183.60 toks/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Result: [\n",
|
||||
" {\n",
|
||||
" \"get_weather\": {\n",
|
||||
" \"location\": \"London\",\n",
|
||||
" \"unit\": \"celsius\"\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"Latency: 0.20116281509399414\n",
|
||||
"Speed: 39.768781304148916\n",
|
||||
"--------------------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"get_weather_api = {\n",
|
||||
" \"name\": \"get_weather\",\n",
|
||||
" \"description\": \"Get the current weather for a location\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"location\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The city and state, e.g. San Francisco, New York\"\n",
|
||||
" },\n",
|
||||
" \"unit\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
|
||||
" \"description\": \"The unit of temperature to return\"\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" \"required\": [\"location\"]\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"search_api = {\n",
|
||||
" \"name\": \"search\",\n",
|
||||
" \"description\": \"Search for information on the internet\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"query\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The search query, e.g. 'latest news on AI'\"\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" \"required\": [\"query\"]\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"get_stock_price_api = {\n",
|
||||
" \"name\": \"get_stock_price\",\n",
|
||||
" \"description\": \"Get the current stock price for a company\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"symbol\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The stock symbol, e.g. 'AAPL' for Apple Inc.\"\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" \"required\": [\"symbol\"]\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"get_news_api = {\n",
|
||||
" \"name\": \"get_news\",\n",
|
||||
" \"description\": \"Get the latest news headlines\",\n",
|
||||
" \"parameters\": {\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" \"properties\": {\n",
|
||||
" \"topic\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"The news topic, e.g. 'technology', 'sports'\"\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" \"required\": [\"topic\"]\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"all_apis = [get_weather_api, search_api, get_stock_price_api, get_news_api]\n",
|
||||
"\n",
|
||||
"test_queries = [\n",
|
||||
" \"What's the weather like in New York in Fahrenheit?\",\n",
|
||||
" \"What is the stock price of CRM?\",\n",
|
||||
" \"Tell me the temperature in London in Celsius\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for query in test_queries:\n",
|
||||
" print(f\"Query: {query}\")\n",
|
||||
" result, metadata = handler.process_query(query, all_apis, TASK_INSTRUCTION, FORMAT_INSTRUCTION)\n",
|
||||
" print(f\"Result: {json.dumps(result, indent=2)}\")\n",
|
||||
" print(\"Latency: \", metadata[\"latency\"])\n",
|
||||
" print(\"Speed: \", metadata[\"tokens_per_second\"])\n",
|
||||
" print(\"-\" * 50)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Executing Real-Time Weather API Calls\n",
|
||||
"\n",
"To make real-time weather API calls, we'll use the `requests` library to fetch data from a weather service (wttr.in), then execute the function calls generated by the xLAM model against it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The current weather in San Francisco is 16.0 celsius\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import ast\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"def get_weather(location, unit):\n",
|
||||
" \"\"\"\n",
|
||||
" Get the current weather for a specified location.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" location (str): The city and state, e.g. San Francisco, New York.\n",
|
||||
" unit (str): The unit of temperature to return, either 'celsius' or 'fahrenheit'.\n",
|
||||
"\n",
|
||||
" Returns:\n",
" float: The temperature in the requested unit. On a failed request, a dict with an 'error' key is returned instead.\n",
|
||||
" \"\"\"\n",
|
||||
" base_url = \"https://wttr.in\"\n",
|
||||
" unit_param = \"m\" if unit == \"celsius\" else \"u\"\n",
|
||||
" params = {\n",
|
||||
" \"format\": \"j1\",\n",
|
||||
" \"unit\": unit_param\n",
|
||||
" }\n",
" response = requests.get(f\"{base_url}/{location}\", params=params, timeout=10)\n",
|
||||
" if response.status_code == 200:\n",
|
||||
" weather_data = response.json()[\"current_condition\"][0]\n",
|
||||
" return float(weather_data[\"temp_C\"]) if unit == \"celsius\" else float(weather_data[\"temp_F\"])\n",
|
||||
" else:\n",
|
||||
" return {\"error\": \"Failed to retrieve weather data\"}\n",
|
||||
" \n",
|
||||
"def execute_function_calls(function_calls):\n",
|
||||
" \"\"\"\n",
|
||||
" Convert the dictionary function_calls to executable Python code and execute the corresponding functions.\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" function_calls (list): A list of dictionaries containing function calls and their arguments.\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" list: A list of results from executing the functions.\n",
|
||||
" \"\"\"\n",
|
||||
" results = []\n",
|
||||
" for function_call in function_calls:\n",
|
||||
" for func_name, args in function_call.items():\n",
|
||||
" if func_name in globals() and callable(globals()[func_name]):\n",
|
||||
" try:\n",
" # Round-trip the arguments through ast.literal_eval so that only Python literals are accepted\n",
|
||||
" safe_args = ast.literal_eval(str(args))\n",
|
||||
" print(safe_args)\n",
|
||||
" # Call the function with unpacked arguments\n",
|
||||
" func_result = globals()[func_name](**safe_args)\n",
|
||||
" results.append(func_result)\n",
|
||||
" except Exception as e:\n",
|
||||
" results.append(f\"Error {str(e)}\")\n",
|
||||
" else:\n",
|
||||
" results.append(\"Error: Function not found or not callable\")\n",
|
||||
" \n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
"# Example usage\n",
|
||||
"location = \"San Francisco\"\n",
|
||||
"unit = \"celsius\"\n",
|
||||
"weather_data = get_weather(location, unit)\n",
|
||||
"print(f\"The current weather in {location} is {weather_data} {unit}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 4.86it/s, Generation Speed: 180.67 toks/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The function call result: [\n",
|
||||
" {\n",
|
||||
" \"get_weather\": {\n",
|
||||
" \"location\": \"San Francisco\",\n",
|
||||
" \"unit\": \"celsius\"\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"{'location': 'San Francisco', 'unit': 'celsius'}\n",
|
||||
"Execution results: [16.0]\n",
|
||||
"--------------------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 4.67it/s, Generation Speed: 183.21 toks/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The function call result: [\n",
|
||||
" {\n",
|
||||
" \"get_weather\": {\n",
|
||||
" \"location\": \"New York\",\n",
|
||||
" \"unit\": \"fahrenheit\"\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"{'location': 'New York', 'unit': 'fahrenheit'}\n",
|
||||
"Execution results: [74.0]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Example 1\n",
|
||||
"query = \"I want to know the weather in San Francisco in Celsius\"\n",
|
||||
"function_calls, metadata = handler.process_query(query, all_apis, TASK_INSTRUCTION, FORMAT_INSTRUCTION)\n",
|
||||
"print(f\"The function call result: {json.dumps(function_calls, indent=2)}\")\n",
|
||||
"execution_results = execute_function_calls(function_calls)\n",
|
||||
"print(\"Execution results: \", execution_results)\n",
|
||||
"print(\"-\" * 50)\n",
|
||||
"\n",
|
||||
"# Example 2\n",
|
||||
"query = \"Tell me the temperature in New York in Fahrenheit\"\n",
|
||||
"function_calls, metadata = handler.process_query(query, all_apis, TASK_INSTRUCTION, FORMAT_INSTRUCTION)\n",
|
||||
"print(f\"The function call result: {json.dumps(function_calls, indent=2)}\")\n",
|
||||
"execution_results = execute_function_calls(function_calls)\n",
|
||||
"print(\"Execution results: \", execution_results)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
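A note on `execute_function_calls` in the notebook above: it resolves tool names through `globals()`, so any callable in the notebook namespace can be reached from model output. The following is a minimal sketch of a stricter variant, assuming an explicit registry of allowed tools. `TOOL_REGISTRY` and `fake_get_weather` are hypothetical names introduced here for illustration (the stub returns fixed values and makes no network call); they are not part of the original notebook.

```python
from typing import Any, Callable, Dict, List


def fake_get_weather(location: str, unit: str = "celsius") -> float:
    """Hypothetical stand-in for the notebook's get_weather; makes no network call."""
    return 16.0 if unit == "celsius" else 60.8


# Explicit allow-list: only functions registered here can be invoked from model output.
TOOL_REGISTRY: Dict[str, Callable[..., Any]] = {"get_weather": fake_get_weather}


def execute_function_calls(function_calls: List[Dict[str, Dict[str, Any]]]) -> List[Any]:
    """Execute parsed tool calls of the form [{"tool_name": {"arg": value}}, ...]."""
    results = []
    for call in function_calls:
        for name, args in call.items():
            func = TOOL_REGISTRY.get(name)
            if func is None:
                # Unknown names are reported instead of being looked up in globals().
                results.append(f"Error: unknown tool {name!r}")
                continue
            try:
                results.append(func(**args))
            except Exception as exc:  # bad or missing arguments, runtime failures
                results.append(f"Error: {exc}")
    return results


calls = [
    {"get_weather": {"location": "San Francisco", "unit": "celsius"}},
    {"delete_files": {"path": "/"}},
]
print(execute_function_calls(calls))
```

Registering the real `get_weather` in place of the stub would keep the notebook's behavior while rejecting any tool name the model invents.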
161
examples/test_prompt_template.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import argparse
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from jinja2 import Template
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Default prompts
|
||||
TASK_INSTRUCTION = """
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the functions can be used, point it out and refuse to answer.
|
||||
If the given question lacks the parameters required by the function, also point it out.
|
||||
""".strip()
|
||||
|
||||
FORMAT_INSTRUCTION = """
|
||||
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
|
||||
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'
|
||||
```
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
||||
... (more tool calls as required)
|
||||
]
|
||||
}
|
||||
```
|
||||
""".strip()
|
||||
|
||||
class PromptAssembler:
|
||||
def __init__(self, model: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
self.chat_template = tokenizer.chat_template
|
||||
|
||||
@staticmethod
|
||||
def apply_chat_template(template, messages):
|
||||
jinja_template = Template(template)
|
||||
return jinja_template.render(messages=messages)
|
||||
|
||||
def assemble_prompt(self, query: str, tools: Dict, task_instruction: str, format_instruction: str):
|
||||
# Convert tools to XLAM format
|
||||
xlam_tools = self.convert_to_xlam_tool(tools)
|
||||
|
||||
# Build the input prompt
|
||||
prompt = self.build_prompt(query, xlam_tools, task_instruction, format_instruction)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
formatted_prompt = self.apply_chat_template(self.chat_template, messages)
|
||||
|
||||
return formatted_prompt
|
||||
|
||||
def convert_to_xlam_tool(self, tools):
|
||||
if isinstance(tools, dict):
|
||||
return {
|
||||
"name": tools["name"],
|
||||
"description": tools["description"],
|
||||
"parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
|
||||
}
|
||||
elif isinstance(tools, list):
|
||||
return [self.convert_to_xlam_tool(tool) for tool in tools]
|
||||
else:
|
||||
return tools
|
||||
|
||||
def build_prompt(self, query, tools, task_instruction=TASK_INSTRUCTION, format_instruction=FORMAT_INSTRUCTION):
|
||||
prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
|
||||
prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
|
||||
prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
|
||||
prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
|
||||
return prompt
|
||||
|
||||
def print_prompt_template(self):
|
||||
template = self.chat_template.replace("{{", "{").replace("}}", "}")
|
||||
print("Prompt Template with Placeholders:")
|
||||
print(template)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Assemble prompts using chat template")
|
||||
parser.add_argument("--model", required=True, help="Name of the model (for chat template)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize the PromptAssembler
|
||||
assembler = PromptAssembler(args.model)
|
||||
|
||||
# Print the prompt template with placeholders
|
||||
assembler.print_prompt_template()
|
||||
|
||||
# Test case 1: Weather API, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
|
||||
weather_api = {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The unit of temperature to return"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
|
||||
# Test queries
|
||||
test_queries = [
|
||||
"What's the weather like in New York?",
|
||||
"Tell me the temperature in London in Celsius",
|
||||
"What's the weather forecast for Tokyo?",
|
||||
"What is the stock price of CRM?", # the model should return an empty list
|
||||
"What's the current temperature in Paris in Fahrenheit?"
|
||||
]
|
||||
|
||||
# Run test cases
|
||||
for query in test_queries:
|
||||
print(f"\nQuery: {query}")
|
||||
formatted_prompt = assembler.assemble_prompt(query, weather_api, TASK_INSTRUCTION, FORMAT_INSTRUCTION)
|
||||
print("Formatted Prompt:")
|
||||
print(formatted_prompt)
|
||||
print("-" * 50)
|
||||
|
||||
# Test case 2: Multiple APIs, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
|
||||
calculator_api = {
|
||||
"name": "calculate",
|
||||
"description": "Perform a mathematical calculation",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"enum": ["add", "subtract", "multiply", "divide"],
|
||||
"description": "The mathematical operation to perform"
|
||||
},
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number"
|
||||
}
|
||||
},
|
||||
"required": ["operation", "x", "y"]
|
||||
}
|
||||
}
|
||||
|
||||
multi_api_query = "What's the weather in Miami and what's 15 multiplied by 7?"
|
||||
print(f"\nMulti-API Query: {multi_api_query}")
|
||||
multi_api_formatted_prompt = assembler.assemble_prompt(
|
||||
multi_api_query,
|
||||
[weather_api, calculator_api],
|
||||
TASK_INSTRUCTION,
|
||||
FORMAT_INSTRUCTION
|
||||
)
|
||||
print("Formatted Prompt:")
|
||||
print(multi_api_formatted_prompt)
|
||||
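To see what `convert_to_xlam_tool` and `build_prompt` produce without downloading a tokenizer, the two methods can be exercised as plain functions. This is a sketch that copies their logic from the file above into module-level functions; the shortened `task` and `fmt` strings are placeholders, not the real instructions. One thing it makes visible: the conversion keeps only the `properties` map, so the OpenAI-style `type` and `required` fields do not reach the prompt.

```python
import json


def convert_to_xlam_tool(tools):
    """Flatten OpenAI-style tool schemas to the name/description/parameters shape."""
    if isinstance(tools, dict):
        return {
            "name": tools["name"],
            "description": tools["description"],
            # Only the per-argument properties survive; "type" and "required" are dropped.
            "parameters": dict(tools["parameters"].get("properties", {})),
        }
    if isinstance(tools, list):
        return [convert_to_xlam_tool(tool) for tool in tools]
    return tools


def build_prompt(query, tools, task_instruction, format_instruction):
    """Assemble the four bracketed sections in the same order as the original method."""
    prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
    prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
    prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
    prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
    return prompt


weather_api = {
    "name": "get_weather",
    "description": "Get the current weather for a location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"],
                     "description": "The unit of temperature to return"},
        },
        "required": ["location"],
    },
}

# Placeholder instructions, used only to keep the printed prompt short.
task = "Pick the right tool call(s) for the question."
fmt = "Answer with a JSON object containing a tool_calls list."

xlam_tool = convert_to_xlam_tool(weather_api)
prompt = build_prompt("What's the weather like in New York?", [xlam_tool], task, fmt)
print(json.dumps(xlam_tool, indent=2))
print(prompt)
```

If a tool's required arguments matter for the model's decision, mentioning them in each property's `description` is one way to carry that information through the conversion.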
188
examples/test_xlam_model_with_endpoint.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from openai import OpenAI
|
||||
|
||||
# Default prompts
|
||||
TASK_INSTRUCTION = """
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the functions can be used, point it out and refuse to answer.
|
||||
If the given question lacks the parameters required by the function, also point it out.
|
||||
""".strip()
|
||||
|
||||
FORMAT_INSTRUCTION = """
|
||||
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
|
||||
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'
|
||||
```
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
||||
... (more tool calls as required)
|
||||
]
|
||||
}
|
||||
```
|
||||
""".strip()
|
||||
|
||||
|
||||
class XLAMHandler:
|
||||
def __init__(self, model_name, temperature=0.3, top_p=1, max_tokens=512, port=8000):
|
||||
self.model_name = model_name
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.max_tokens = max_tokens
|
||||
base_url = f"http://localhost:{port}/v1"
|
||||
self.client = OpenAI(api_key="Empty", base_url=base_url)
|
||||
|
||||
def process_query(self, query, tools, task_instruction, format_instruction):
|
||||
# Convert tools to XLAM format
|
||||
xlam_tools = self.convert_to_xlam_tool(tools)
|
||||
|
||||
# Build the input prompt
|
||||
prompt = self.build_prompt(query, xlam_tools, task_instruction, format_instruction)
|
||||
|
||||
# Create message for API call
|
||||
message = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Make API call
|
||||
start_time = time.time()
|
||||
response = self.client.chat.completions.create(
|
||||
messages=message,
|
||||
model=self.model_name,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
top_p=self.top_p,
|
||||
)
|
||||
latency = time.time() - start_time
|
||||
|
||||
# Parse response
|
||||
result = response.choices[0].message.content
|
||||
parsed_result, success, _ = self.parse_response(result)
|
||||
|
||||
# Prepare metadata
|
||||
metadata = {
|
||||
"input_tokens": response.usage.prompt_tokens,
|
||||
"output_tokens": response.usage.completion_tokens,
|
||||
"latency": latency
|
||||
}
|
||||
|
||||
return parsed_result, metadata
|
||||
|
||||
def convert_to_xlam_tool(self, tools):
|
||||
if isinstance(tools, dict):
|
||||
return {
|
||||
"name": tools["name"],
|
||||
"description": tools["description"],
|
||||
"parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
|
||||
}
|
||||
elif isinstance(tools, list):
|
||||
return [self.convert_to_xlam_tool(tool) for tool in tools]
|
||||
else:
|
||||
return tools
|
||||
|
||||
def build_prompt(self, query, tools, task_instruction=TASK_INSTRUCTION, format_instruction=FORMAT_INSTRUCTION):
|
||||
prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
|
||||
prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
|
||||
prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
|
||||
prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
|
||||
return prompt
|
||||
|
||||
def parse_response(self, response):
|
||||
try:
|
||||
data = json.loads(response)
|
||||
tool_calls = data.get('tool_calls', []) if isinstance(data, dict) else data
|
||||
result = [
|
||||
{tool_call['name']: tool_call['arguments']}
|
||||
for tool_call in tool_calls if isinstance(tool_call, dict)
|
||||
]
|
||||
return result, True, []
|
||||
except (json.JSONDecodeError, KeyError, TypeError):
|
||||
return [], False, ["Failed to parse JSON response"]
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test XLAM model with endpoint")
    parser.add_argument("--model_name", default="xlam-1b-fc-r", help="Name of the model")
    parser.add_argument("--port", type=int, default=8001, help="Port number for the endpoint")
    parser.add_argument("--temperature", type=float, default=0.3, help="Temperature for sampling")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top p for sampling")
    parser.add_argument("--max_tokens", type=int, default=512, help="Maximum number of tokens to generate")

    args = parser.parse_args()

    # Initialize the XLAMHandler with command-line arguments
    handler = XLAMHandler(args.model_name, temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens, port=args.port)

    # Test case 1: Weather API, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
    weather_api = {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature to return"
                }
            },
            "required": ["location"]
        }
    }

    # Test queries
    test_queries = [
        "What's the weather like in New York?",
        "Tell me the temperature in London in Celsius",
        "What's the weather forecast for Tokyo?",
        "What is the stock price of CRM?",  # the model should return an empty list: no available tool fits this query, so it declines to call one
        "What's the current temperature in Paris in Fahrenheit?"
    ]

    # Run test cases
    for query in test_queries:
        print(f"Query: {query}")
        result, metadata = handler.process_query(query, weather_api, TASK_INSTRUCTION, FORMAT_INSTRUCTION)
        print(f"Result: {json.dumps(result, indent=2)}")
        print(f"Metadata: {json.dumps(metadata, indent=2)}")
        print("-" * 50)

    # Test case 2: Multiple APIs, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
    calculator_api = {
        "name": "calculate",
        "description": "Perform a mathematical calculation",
        "parameters": {
            "type": "object",
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": ["add", "subtract", "multiply", "divide"],
                    "description": "The mathematical operation to perform"
                },
                "x": {
                    "type": "number",
                    "description": "The first number"
                },
                "y": {
                    "type": "number",
                    "description": "The second number"
                }
            },
            "required": ["operation", "x", "y"]
        }
    }

    multi_api_query = "What's the weather in Miami and what's 15 multiplied by 7?"
    multi_api_result, multi_api_metadata = handler.process_query(
        multi_api_query,
        [weather_api, calculator_api],
        TASK_INSTRUCTION,
        FORMAT_INSTRUCTION
    )

    print("Multi-API Query Result:")
    print(json.dumps(multi_api_result, indent=2))
    print(f"Metadata: {json.dumps(multi_api_metadata, indent=2)}")
193
examples/test_xlam_model_with_vllm.py
Normal file
@@ -0,0 +1,193 @@
import json
import time
import argparse
from typing import List, Dict

from vllm import LLM, SamplingParams
from jinja2 import Template


# Default prompts
TASK_INSTRUCTION = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
""".strip()

FORMAT_INSTRUCTION = """
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
""".strip()


class XLAMHandler:
    def __init__(self, model: str, temperature: float = 0.3, top_p: float = 1, max_tokens: int = 512):
        self.llm = LLM(model=model)
        self.sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens
        )
        self.chat_template = self.llm.get_tokenizer().chat_template

    @staticmethod
    def apply_chat_template(template, messages):
        jinja_template = Template(template)
        return jinja_template.render(messages=messages)

    def process_query(self, query: str, tools: Dict, task_instruction: str, format_instruction: str):
        # `tools` may be a single tool dict or a list of tool dicts (OpenAI function format).
        # Convert tools to XLAM format
        xlam_tools = self.convert_to_xlam_tool(tools)

        # Build the input prompt
        prompt = self.build_prompt(query, xlam_tools, task_instruction, format_instruction)

        messages = [
            {"role": "user", "content": prompt}
        ]
        formatted_prompt = self.apply_chat_template(self.chat_template, messages)

        # Make inference
        start_time = time.time()
        outputs = self.llm.generate([formatted_prompt], self.sampling_params)
        latency = time.time() - start_time

        # Parse response
        result = outputs[0].outputs[0].text
        parsed_result, success, _ = self.parse_response(result)

        # Prepare metadata
        metadata = {
            "latency": latency,
            "success": success,
        }

        return parsed_result, metadata

    def convert_to_xlam_tool(self, tools):
        if isinstance(tools, dict):
            return {
                "name": tools["name"],
                "description": tools["description"],
                "parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
            }
        elif isinstance(tools, list):
            return [self.convert_to_xlam_tool(tool) for tool in tools]
        else:
            return tools

    def build_prompt(self, query, tools, task_instruction=TASK_INSTRUCTION, format_instruction=FORMAT_INSTRUCTION):
        prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
        prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
        prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
        prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
        return prompt

    def parse_response(self, response):
        # Returns (tool_calls, success, errors). Malformed output (invalid JSON,
        # a non-iterable payload, or a call missing "name"/"arguments") is
        # reported as a failure instead of raising.
        try:
            data = json.loads(response)
            tool_calls = data.get('tool_calls', []) if isinstance(data, dict) else data
            result = [
                {tool_call['name']: tool_call['arguments']}
                for tool_call in tool_calls if isinstance(tool_call, dict)
            ]
            return result, True, []
        except (json.JSONDecodeError, KeyError, TypeError):
            return [], False, ["Failed to parse JSON response"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test XLAM model with vLLM")
    parser.add_argument("--model", required=True, help="Path to the model")
    parser.add_argument("--temperature", type=float, default=0.3, help="Temperature for sampling")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top p for sampling")
    parser.add_argument("--max_tokens", type=int, default=512, help="Maximum number of tokens to generate")

    args = parser.parse_args()

    # Initialize the XLAMHandler with command-line arguments
    handler = XLAMHandler(args.model, temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens)

    # Test case 1: Weather API, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
    weather_api = {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature to return"
                }
            },
            "required": ["location"]
        }
    }

    # Test queries
    test_queries = [
        "What's the weather like in New York?",
        "Tell me the temperature in London in Celsius",
        "What's the weather forecast for Tokyo?",
        "What is the stock price of CRM?",  # the model should return an empty list
        "What's the current temperature in Paris in Fahrenheit?"
    ]

    # Run test cases
    for query in test_queries:
        print(f"Query: {query}")
        result, metadata = handler.process_query(query, weather_api, TASK_INSTRUCTION, FORMAT_INSTRUCTION)
        print(f"Result: {json.dumps(result, indent=2)}")
        print(f"Metadata: {json.dumps(metadata, indent=2)}")
        print("-" * 50)

    # Test case 2: Multiple APIs, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
    calculator_api = {
        "name": "calculate",
        "description": "Perform a mathematical calculation",
        "parameters": {
            "type": "object",
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": ["add", "subtract", "multiply", "divide"],
                    "description": "The mathematical operation to perform"
                },
                "x": {
                    "type": "number",
                    "description": "The first number"
                },
                "y": {
                    "type": "number",
                    "description": "The second number"
                }
            },
            "required": ["operation", "x", "y"]
        }
    }

    multi_api_query = "What's the weather in Miami and what's 15 multiplied by 7?"
    multi_api_result, multi_api_metadata = handler.process_query(
        multi_api_query,
        [weather_api, calculator_api],
        TASK_INSTRUCTION,
        FORMAT_INSTRUCTION
    )

    print("Multi-API Query Result:")
    print(json.dumps(multi_api_result, indent=2))
    print(f"Metadata: {json.dumps(multi_api_metadata, indent=2)}")