162 lines
6.0 KiB
Python
162 lines
6.0 KiB
Python
|
|
import argparse
|
||
|
|
import json
|
||
|
|
from typing import Dict
|
||
|
|
|
||
|
|
from jinja2 import Template
|
||
|
|
from transformers import AutoTokenizer
|
||
|
|
|
||
|
|
# Default prompts
|
||
|
|
# Default task instruction: the text placed between the
# "[BEGIN OF TASK INSTRUCTION]" / "[END OF TASK INSTRUCTION]" markers.
# One sentence per line; .strip() removes the newline that follows the opening
# quotes and the one that precedes the closing quotes.
TASK_INSTRUCTION = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
""".strip()
|
||
|
|
|
||
|
|
# Default format instruction: the text placed between the
# "[BEGIN OF FORMAT INSTRUCTION]" / "[END OF FORMAT INSTRUCTION]" markers.
# It tells the model to answer with a JSON object holding a "tool_calls" list
# and shows one example entry inside a fenced block.
# NOTE(review): any leading whitespace that originally indented the lines of
# the JSON example is not recoverable from this copy of the file — confirm the
# exact spacing against the prompt the model was trained with.
FORMAT_INSTRUCTION = """
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
""".strip()
|
||
|
|
|
||
|
|
class PromptAssembler:
    """Assemble function-calling prompts and render them with a chat template.

    A prompt is built from four bracket-delimited sections (task instruction,
    available tools, format instruction, query), wrapped in a single ``user``
    message, and rendered through the chat template taken from the model's
    tokenizer.
    """

    def __init__(self, model: str):
        """Load the tokenizer for ``model`` and keep its chat template.

        Args:
            model: Model name or path handed to
                ``AutoTokenizer.from_pretrained``.
        """
        tokenizer = AutoTokenizer.from_pretrained(model)
        # NOTE(review): tokenizers without a chat template are expected to
        # expose None here; the methods that need the template check for it.
        self.chat_template = tokenizer.chat_template

    @staticmethod
    def apply_chat_template(template, messages):
        """Render ``messages`` through the Jinja2 chat ``template``.

        Args:
            template: Jinja2 source text of the chat template.
            messages: Value exposed to the template as ``messages``
                (this class passes a list with one ``role``/``content`` dict).

        Returns:
            The rendered prompt string.

        Raises:
            ValueError: If ``template`` is None, because no meaningful prompt
                can be rendered without a template.
        """
        if template is None:
            raise ValueError(
                "No chat template available: the tokenizer does not define one."
            )
        return Template(template).render(messages=messages)

    def assemble_prompt(
        self,
        query: str,
        tools,
        task_instruction: str = TASK_INSTRUCTION,
        format_instruction: str = FORMAT_INSTRUCTION,
    ):
        """Build the full chat-formatted prompt for ``query`` and ``tools``.

        Args:
            query: The user question.
            tools: One tool definition (dict) or a list of them; see
                ``convert_to_xlam_tool`` for the accepted shapes.
            task_instruction: Text for the task-instruction section. Defaults
                to ``TASK_INSTRUCTION``, matching ``build_prompt``.
            format_instruction: Text for the format-instruction section.
                Defaults to ``FORMAT_INSTRUCTION``, matching ``build_prompt``.

        Returns:
            The prompt rendered through ``self.chat_template``.

        Raises:
            ValueError: If no chat template is available.
        """
        xlam_tools = self.convert_to_xlam_tool(tools)
        prompt = self.build_prompt(query, xlam_tools, task_instruction, format_instruction)
        messages = [{"role": "user", "content": prompt}]
        return self.apply_chat_template(self.chat_template, messages)

    def convert_to_xlam_tool(self, tools):
        """Convert tool definition(s) to the flat name/description/parameters form.

        A dict becomes ``{"name", "description", "parameters"}`` where
        ``parameters`` is a copy of the input's ``parameters["properties"]``
        mapping; other schema keys such as ``type`` and ``required`` are not
        carried over. A list is converted element by element. Any other value
        is returned unchanged.

        Args:
            tools: A tool dict, a list of tool dicts, or any other value.
                A dict without ``name`` but with a dict under ``function``
                (the ``{"type": "function", "function": {...}}`` wrapper) is
                unwrapped first.

        Returns:
            The converted dict, a list of converted items, or ``tools`` itself.

        Raises:
            KeyError: If a tool dict has no ``name``.
        """
        if isinstance(tools, list):
            return [self.convert_to_xlam_tool(tool) for tool in tools]
        if not isinstance(tools, dict):
            return tools

        # Accept the wrapped tool form as well as the bare function schema.
        if "name" not in tools and isinstance(tools.get("function"), dict):
            tools = tools["function"]

        # A tool may omit its description or take no parameters at all;
        # treat both as empty instead of failing with KeyError.
        parameters = tools.get("parameters") or {}
        return {
            "name": tools["name"],
            "description": tools.get("description", ""),
            "parameters": dict(parameters.get("properties") or {}),
        }

    def build_prompt(self, query, tools, task_instruction=TASK_INSTRUCTION, format_instruction=FORMAT_INSTRUCTION):
        """Concatenate the four bracket-delimited prompt sections.

        Args:
            query: The user question.
            tools: JSON-serialisable tool description(s); embedded via
                ``json.dumps``.
            task_instruction: Text for the task-instruction section.
            format_instruction: Text for the format-instruction section.

        Returns:
            The prompt text; every section is followed by a blank line,
            including the last one.
        """
        sections = (
            ("TASK INSTRUCTION", task_instruction),
            ("AVAILABLE TOOLS", json.dumps(tools)),
            ("FORMAT INSTRUCTION", format_instruction),
            ("QUERY", query),
        )
        return "".join(
            f"[BEGIN OF {title}]\n{body}\n[END OF {title}]\n\n"
            for title, body in sections
        )

    def print_prompt_template(self):
        """Print the chat template with doubled braces collapsed to single ones.

        Raises:
            ValueError: If no chat template is available.
        """
        if self.chat_template is None:
            raise ValueError(
                "No chat template available: the tokenizer does not define one."
            )
        # "{{ x }}" -> "{ x }" so Jinja expressions read as plain placeholders.
        template = self.chat_template.replace("{{", "{").replace("}}", "}")
        print("Prompt Template with Placeholders:")
        print(template)
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Demo driver: load the chat template of the model named on the command
    # line, print it, then print the assembled prompt for a handful of sample
    # queries against one tool and against two tools.
    cli_parser = argparse.ArgumentParser(description="Assemble prompts using chat template")
    cli_parser.add_argument("--model", required=True, help="Name of the model (for chat template)")
    cli_args = cli_parser.parse_args()

    # Initialize the PromptAssembler and show its template with placeholders.
    assembler = PromptAssembler(cli_args.model)
    assembler.print_prompt_template()

    # Test case 1: Weather API, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
    location_param = {
        "type": "string",
        "description": "The city and state, e.g. San Francisco, CA",
    }
    unit_param = {
        "type": "string",
        "enum": ["celsius", "fahrenheit"],
        "description": "The unit of temperature to return",
    }
    weather_api = {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {"location": location_param, "unit": unit_param},
            "required": ["location"],
        },
    }

    # Sample queries for the single-tool case.
    test_queries = [
        "What's the weather like in New York?",
        "Tell me the temperature in London in Celsius",
        "What's the weather forecast for Tokyo?",
        "What is the stock price of CRM?",  # the model should return an empty list
        "What's the current temperature in Paris in Fahrenheit?",
    ]

    # Print one formatted prompt per query, separated by a dashed rule.
    separator = "-" * 50
    for user_query in test_queries:
        print(f"\nQuery: {user_query}")
        rendered = assembler.assemble_prompt(
            user_query, weather_api, TASK_INSTRUCTION, FORMAT_INSTRUCTION
        )
        print("Formatted Prompt:")
        print(rendered)
        print(separator)

    # Test case 2: Multiple APIs, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
    operation_param = {
        "type": "string",
        "enum": ["add", "subtract", "multiply", "divide"],
        "description": "The mathematical operation to perform",
    }
    x_param = {"type": "number", "description": "The first number"}
    y_param = {"type": "number", "description": "The second number"}
    calculator_api = {
        "name": "calculate",
        "description": "Perform a mathematical calculation",
        "parameters": {
            "type": "object",
            "properties": {"operation": operation_param, "x": x_param, "y": y_param},
            "required": ["operation", "x", "y"],
        },
    }

    multi_api_query = "What's the weather in Miami and what's 15 multiplied by 7?"
    print(f"\nMulti-API Query: {multi_api_query}")
    both_tools = [weather_api, calculator_api]
    multi_api_formatted_prompt = assembler.assemble_prompt(
        multi_api_query, both_tools, TASK_INSTRUCTION, FORMAT_INSTRUCTION
    )
    print("Formatted Prompt:")
    print(multi_api_formatted_prompt)
|