初始化项目,由ModelHub XC社区提供模型
Model: Salesforce/xLAM-1b-fc-r Source: Original Platform
This commit is contained in:
161
examples/test_prompt_template.py
Normal file
161
examples/test_prompt_template.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import argparse
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from jinja2 import Template
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Default prompts
|
||||
TASK_INSTRUCTION = """
|
||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the functions can be used, point it out and refuse to answer.
|
||||
If the given question lacks the parameters required by the function, also point it out.
|
||||
""".strip()
|
||||
|
||||
FORMAT_INSTRUCTION = """
|
||||
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
|
||||
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'
|
||||
```
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
||||
... (more tool calls as required)
|
||||
]
|
||||
}
|
||||
```
|
||||
""".strip()
|
||||
|
||||
class PromptAssembler:
|
||||
def __init__(self, model: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
self.chat_template = tokenizer.chat_template
|
||||
|
||||
@staticmethod
|
||||
def apply_chat_template(template, messages):
|
||||
jinja_template = Template(template)
|
||||
return jinja_template.render(messages=messages)
|
||||
|
||||
def assemble_prompt(self, query: str, tools: Dict, task_instruction: str, format_instruction: str):
|
||||
# Convert tools to XLAM format
|
||||
xlam_tools = self.convert_to_xlam_tool(tools)
|
||||
|
||||
# Build the input prompt
|
||||
prompt = self.build_prompt(query, xlam_tools, task_instruction, format_instruction)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
formatted_prompt = self.apply_chat_template(self.chat_template, messages)
|
||||
|
||||
return formatted_prompt
|
||||
|
||||
def convert_to_xlam_tool(self, tools):
|
||||
if isinstance(tools, dict):
|
||||
return {
|
||||
"name": tools["name"],
|
||||
"description": tools["description"],
|
||||
"parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
|
||||
}
|
||||
elif isinstance(tools, list):
|
||||
return [self.convert_to_xlam_tool(tool) for tool in tools]
|
||||
else:
|
||||
return tools
|
||||
|
||||
def build_prompt(self, query, tools, task_instruction=TASK_INSTRUCTION, format_instruction=FORMAT_INSTRUCTION):
|
||||
prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
|
||||
prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
|
||||
prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
|
||||
prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
|
||||
return prompt
|
||||
|
||||
def print_prompt_template(self):
|
||||
template = self.chat_template.replace("{{", "{").replace("}}", "}")
|
||||
print("Prompt Template with Placeholders:")
|
||||
print(template)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Assemble prompts using chat template")
|
||||
parser.add_argument("--model", required=True, help="Name of the model (for chat template)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize the PromptAssembler
|
||||
assembler = PromptAssembler(args.model)
|
||||
|
||||
# Print the prompt template with placeholders
|
||||
assembler.print_prompt_template()
|
||||
|
||||
# Test case 1: Weather API, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
|
||||
weather_api = {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The unit of temperature to return"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
|
||||
# Test queries
|
||||
test_queries = [
|
||||
"What's the weather like in New York?",
|
||||
"Tell me the temperature in London in Celsius",
|
||||
"What's the weather forecast for Tokyo?",
|
||||
"What is the stock price of CRM?", # the model should return an empty list
|
||||
"What's the current temperature in Paris in Fahrenheit?"
|
||||
]
|
||||
|
||||
# Run test cases
|
||||
for query in test_queries:
|
||||
print(f"\nQuery: {query}")
|
||||
formatted_prompt = assembler.assemble_prompt(query, weather_api, TASK_INSTRUCTION, FORMAT_INSTRUCTION)
|
||||
print("Formatted Prompt:")
|
||||
print(formatted_prompt)
|
||||
print("-" * 50)
|
||||
|
||||
# Test case 2: Multiple APIs, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
|
||||
calculator_api = {
|
||||
"name": "calculate",
|
||||
"description": "Perform a mathematical calculation",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"enum": ["add", "subtract", "multiply", "divide"],
|
||||
"description": "The mathematical operation to perform"
|
||||
},
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "The first number"
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "The second number"
|
||||
}
|
||||
},
|
||||
"required": ["operation", "x", "y"]
|
||||
}
|
||||
}
|
||||
|
||||
multi_api_query = "What's the weather in Miami and what's 15 multiplied by 7?"
|
||||
print(f"\nMulti-API Query: {multi_api_query}")
|
||||
multi_api_formatted_prompt = assembler.assemble_prompt(
|
||||
multi_api_query,
|
||||
[weather_api, calculator_api],
|
||||
TASK_INSTRUCTION,
|
||||
FORMAT_INSTRUCTION
|
||||
)
|
||||
print("Formatted Prompt:")
|
||||
print(multi_api_formatted_prompt)
|
||||
Reference in New Issue
Block a user