# Source listing metadata: 194 lines, 7.3 KiB, Python
import argparse
import json
import re
import time
from typing import Any, Dict, List, Tuple, Union

from jinja2 import Template
from vllm import LLM, SamplingParams
# Default prompts

# Task section of every xLAM prompt: tells the model to compose tool calls and
# to refuse / ask when no tool fits or required parameters are missing.
# (The backslash after the opening quotes drops the leading newline, so no
# .strip() is needed.)
TASK_INSTRUCTION = """\
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out."""

# Format section of every xLAM prompt: the JSON shape the model must answer in.
FORMAT_INSTRUCTION = """\
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```"""
class XLAMHandler:
    """Run xLAM function-calling inference through vLLM.

    The pipeline for one query is:
    1. Convert OpenAI-style tool schemas to the compact xLAM tool format.
    2. Build the bracketed xLAM prompt (task / tools / format / query sections).
    3. Render the prompt through the model's chat template and generate with vLLM.
    4. Parse the JSON ``tool_calls`` answer into ``[{name: arguments}, ...]``.
    """

    # Matches a whole response wrapped in a Markdown code fence (``` or ```json).
    # The format instruction shows its example inside such a fence, so models
    # frequently echo it.
    _CODE_FENCE = re.compile(r"```[\w-]*\s*(.*?)\s*```", re.DOTALL)

    def __init__(self, model: str, temperature: float = 0.3, top_p: float = 1, max_tokens: int = 512):
        """Load *model* into vLLM and prepare the sampling parameters.

        Args:
            model: Model name or path passed to ``vllm.LLM``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            max_tokens: Maximum number of tokens generated per query.
        """
        self.llm = LLM(model=model)
        self.sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        # Raw Jinja source of the tokenizer's chat template; may be None when the
        # tokenizer defines no template (rejected explicitly in apply_chat_template).
        self.chat_template = self.llm.get_tokenizer().chat_template

    @staticmethod
    def apply_chat_template(template: str, messages: List[Dict[str, str]], **render_kwargs: Any) -> str:
        """Render a Hugging Face-style chat template for *messages*.

        Args:
            template: Jinja2 source of the chat template.
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            **render_kwargs: Extra template variables, e.g.
                ``add_generation_prompt=True`` to append the assistant-turn header.

        Returns:
            The rendered prompt string.

        Raises:
            ValueError: If *template* is None.
        """
        if template is None:
            raise ValueError("The tokenizer has no chat template; cannot format the messages.")
        # Hugging Face renders chat templates with trim_blocks/lstrip_blocks and the
        # loopcontrols extension. Templates are written for that environment, so
        # render the same way to avoid stray newlines and {% break %} failures.
        # NOTE(review): bos_token is not supplied and therefore renders empty;
        # presumably vLLM's tokenizer adds BOS when encoding the prompt — confirm
        # per model.
        jinja_template = Template(
            template,
            trim_blocks=True,
            lstrip_blocks=True,
            extensions=("jinja2.ext.loopcontrols",),
        )
        return jinja_template.render(messages=messages, **render_kwargs)

    def process_query(
        self,
        query: str,
        tools: Union[Dict, List[Dict]],
        task_instruction: str = TASK_INSTRUCTION,
        format_instruction: str = FORMAT_INSTRUCTION,
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Ask the model which tool calls answer *query*.

        Args:
            query: The user question.
            tools: One OpenAI-style tool schema or a list of them.
            task_instruction: Task section of the prompt.
            format_instruction: Output-format section of the prompt.

        Returns:
            ``(tool_calls, metadata)``.
            - ``tool_calls`` is a list of ``{name: arguments}`` dicts.
            - ``metadata`` has these keys:
              - ``latency``: generation time in seconds;
              - ``success``: whether the output parsed cleanly;
              - ``errors``: parse error messages;
              - ``raw_output``: the model's unparsed text.
        """
        xlam_tools = self.convert_to_xlam_tool(tools)
        prompt = self.build_prompt(query, xlam_tools, task_instruction, format_instruction)
        messages = [{"role": "user", "content": prompt}]
        # add_generation_prompt asks the template to end with the assistant-turn
        # header so the model starts answering instead of continuing the user turn.
        formatted_prompt = self.apply_chat_template(self.chat_template, messages, add_generation_prompt=True)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        outputs = self.llm.generate([formatted_prompt], self.sampling_params)
        latency = time.perf_counter() - start_time

        raw_output = outputs[0].outputs[0].text
        parsed_result, success, errors = self.parse_response(raw_output)

        metadata = {
            "latency": latency,
            "success": success,
            "errors": errors,
            "raw_output": raw_output,
        }
        return parsed_result, metadata

    def convert_to_xlam_tool(self, tools):
        """Convert OpenAI-style tool schema(s) to the xLAM tool format.

        A schema becomes ``{"name", "description", "parameters"}``, where
        ``parameters`` holds only the JSON-schema ``properties`` mapping.
        Lists and tuples are converted element-wise (returning a list). The
        wrapped ``{"type": "function", "function": {...}}`` shape is unwrapped.
        Any other value is returned unchanged.
        """
        if isinstance(tools, dict):
            if "name" not in tools and isinstance(tools.get("function"), dict):
                tools = tools["function"]
            # "parameters" and "description" are optional in the OpenAI format.
            parameters = tools.get("parameters") or {}
            return {
                "name": tools["name"],
                "description": tools.get("description", ""),
                "parameters": dict(parameters.get("properties") or {}),
            }
        if isinstance(tools, (list, tuple)):
            return [self.convert_to_xlam_tool(tool) for tool in tools]
        return tools

    def build_prompt(self, query, tools, task_instruction=TASK_INSTRUCTION, format_instruction=FORMAT_INSTRUCTION):
        """Assemble the xLAM prompt from its four bracketed sections.

        The sections are, in order: task instruction, available tools
        (serialized with ``json.dumps``), format instruction and query. Each
        section is followed by a blank line.
        """
        sections = (
            f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n",
            f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n",
            f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n",
            f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n",
        )
        return "".join(sections)

    def parse_response(self, response: str) -> Tuple[List[Dict[str, Any]], bool, List[str]]:
        """Parse the model's JSON answer into ``[{name: arguments}, ...]``.

        Accepted shapes:
        - ``{"tool_calls": [...]}``;
        - a bare JSON list of calls;
        - either of the above wrapped in a Markdown code fence.

        Entries that are not objects with a string ``name`` are skipped and
        reported. A missing ``arguments`` key is read as ``{}``.

        Returns:
            ``(calls, success, errors)``. ``success`` is True only when the JSON
            parsed and every entry was usable.
        """
        text = response.strip()
        fenced = self._CODE_FENCE.fullmatch(text)
        if fenced:
            text = fenced.group(1)

        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            return [], False, ["Failed to parse JSON response"]

        tool_calls = data.get("tool_calls") if isinstance(data, dict) else data
        if tool_calls is None:
            # A missing/null "tool_calls" (or a top-level null) means no call is needed.
            tool_calls = []
        if not isinstance(tool_calls, list):
            return [], False, ["Expected 'tool_calls' to be a JSON list"]

        result: List[Dict[str, Any]] = []
        errors: List[str] = []
        for index, tool_call in enumerate(tool_calls):
            name = tool_call.get("name") if isinstance(tool_call, dict) else None
            if not isinstance(name, str):
                errors.append(f"Skipped tool call #{index}: expected an object with a string 'name'")
                continue
            # A call to a zero-argument function may legitimately omit "arguments".
            result.append({name: tool_call.get("arguments", {})})
        return result, not errors, errors
if __name__ == "__main__":
    # Command-line options: which model to load plus the vLLM sampling settings.
    cli = argparse.ArgumentParser(description="Test XLAM model with vLLM")
    cli.add_argument("--model", required=True, help="Path to the model")
    cli.add_argument("--temperature", type=float, default=0.3, help="Temperature for sampling")
    cli.add_argument("--top_p", type=float, default=1.0, help="Top p for sampling")
    cli.add_argument("--max_tokens", type=int, default=512, help="Maximum number of tokens to generate")
    options = cli.parse_args()

    # A single handler (one loaded model) serves every test below.
    xlam = XLAMHandler(
        options.model,
        temperature=options.temperature,
        top_p=options.top_p,
        max_tokens=options.max_tokens,
    )

    # Tool schemas written in the OpenAI function-calling format:
    # https://platform.openai.com/docs/guides/function-calling
    weather_tool = {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature to return",
                },
            },
            "required": ["location"],
        },
    }
    calculator_tool = {
        "name": "calculate",
        "description": "Perform a mathematical calculation",
        "parameters": {
            "type": "object",
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": ["add", "subtract", "multiply", "divide"],
                    "description": "The mathematical operation to perform",
                },
                "x": {
                    "type": "number",
                    "description": "The first number",
                },
                "y": {
                    "type": "number",
                    "description": "The second number",
                },
            },
            "required": ["operation", "x", "y"],
        },
    }

    # Test case 1: only the weather tool is offered. The stock-price question has
    # no matching tool, so the model should answer with an empty list.
    weather_queries = (
        "What's the weather like in New York?",
        "Tell me the temperature in London in Celsius",
        "What's the weather forecast for Tokyo?",
        "What is the stock price of CRM?",
        "What's the current temperature in Paris in Fahrenheit?",
    )
    for question in weather_queries:
        print(f"Query: {question}")
        calls, info = xlam.process_query(question, weather_tool, TASK_INSTRUCTION, FORMAT_INSTRUCTION)
        print(f"Result: {json.dumps(calls, indent=2)}")
        print(f"Metadata: {json.dumps(info, indent=2)}")
        print("-" * 50)

    # Test case 2: both tools are offered for one question that needs each of them.
    combined_query = "What's the weather in Miami and what's 15 multiplied by 7?"
    combined_calls, combined_info = xlam.process_query(
        combined_query,
        [weather_tool, calculator_tool],
        TASK_INSTRUCTION,
        FORMAT_INSTRUCTION,
    )
    print("Multi-API Query Result:")
    print(json.dumps(combined_calls, indent=2))
    print(f"Metadata: {json.dumps(combined_info, indent=2)}")