Initialize the project; model provided by the ModelHub XC community
Model: Salesforce/xLAM-1b-fc-r
Source: Original Platform
193 examples/test_xlam_model_with_vllm.py (Normal file)
@@ -0,0 +1,193 @@
import json
import time
import argparse
from typing import List, Dict

from vllm import LLM, SamplingParams
from jinja2 import Template


# Default prompts
TASK_INSTRUCTION = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
""".strip()

FORMAT_INSTRUCTION = """
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'
```
{
    "tool_calls": [
        {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
        ... (more tool calls as required)
    ]
}
```
""".strip()


class XLAMHandler:
    def __init__(self, model: str, temperature: float = 0.3, top_p: float = 1, max_tokens: int = 512):
        self.llm = LLM(model=model)
        self.sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens
        )
        self.chat_template = self.llm.get_tokenizer().chat_template

    @staticmethod
    def apply_chat_template(template, messages):
        jinja_template = Template(template)
        return jinja_template.render(messages=messages)

    def process_query(self, query: str, tools: Dict, task_instruction: str, format_instruction: str):
        # Convert tools to XLAM format
        xlam_tools = self.convert_to_xlam_tool(tools)

        # Build the input prompt
        prompt = self.build_prompt(query, xlam_tools, task_instruction, format_instruction)

        messages = [
            {"role": "user", "content": prompt}
        ]
        formatted_prompt = self.apply_chat_template(self.chat_template, messages)

        # Make inference
        start_time = time.time()
        outputs = self.llm.generate([formatted_prompt], self.sampling_params)
        latency = time.time() - start_time

        # Parse response
        result = outputs[0].outputs[0].text
        parsed_result, success, _ = self.parse_response(result)

        # Prepare metadata
        metadata = {
            "latency": latency,
            "success": success,
        }

        return parsed_result, metadata

    def convert_to_xlam_tool(self, tools):
        if isinstance(tools, dict):
            return {
                "name": tools["name"],
                "description": tools["description"],
                "parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
            }
        elif isinstance(tools, list):
            return [self.convert_to_xlam_tool(tool) for tool in tools]
        else:
            return tools

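    # Illustrative sketch of what convert_to_xlam_tool produces: the OpenAI-format
    # weather tool defined in __main__ below would be flattened to roughly
    #   {"name": "get_weather",
    #    "description": "Get the current weather for a location",
    #    "parameters": {"location": {...}, "unit": {...}}}
    # i.e. only the "properties" map is kept, while "type" and "required" are dropped.
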
    def build_prompt(self, query, tools, task_instruction=TASK_INSTRUCTION, format_instruction=FORMAT_INSTRUCTION):
        prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
        prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
        prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
        prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
        return prompt

    def parse_response(self, response):
        try:
            data = json.loads(response)
            tool_calls = data.get('tool_calls', []) if isinstance(data, dict) else data
            result = [
                {tool_call['name']: tool_call['arguments']}
                for tool_call in tool_calls if isinstance(tool_call, dict)
            ]
            return result, True, []
        except json.JSONDecodeError:
            return [], False, ["Failed to parse JSON response"]


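# Sketch of the round trip parse_response performs: a well-formed completion such as
#   {"tool_calls": [{"name": "get_weather", "arguments": {"location": "New York"}}]}
# is turned into
#   [{"get_weather": {"location": "New York"}}]
# while any output that is not valid JSON yields ([], False, ["Failed to parse JSON response"]).
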
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test XLAM model with vLLM")
    parser.add_argument("--model", required=True, help="Path to the model")
    parser.add_argument("--temperature", type=float, default=0.3, help="Temperature for sampling")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top p for sampling")
    parser.add_argument("--max_tokens", type=int, default=512, help="Maximum number of tokens to generate")

    args = parser.parse_args()

    # Initialize the XLAMHandler with command-line arguments
    handler = XLAMHandler(args.model, temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens)

    # Test case 1: Weather API, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
    weather_api = {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature to return"
                }
            },
            "required": ["location"]
        }
    }

    # Test queries
    test_queries = [
        "What's the weather like in New York?",
        "Tell me the temperature in London in Celsius",
        "What's the weather forecast for Tokyo?",
        "What is the stock price of CRM?",  # the model should return an empty list
        "What's the current temperature in Paris in Fahrenheit?"
    ]

    # Run test cases
    for query in test_queries:
        print(f"Query: {query}")
        result, metadata = handler.process_query(query, weather_api, TASK_INSTRUCTION, FORMAT_INSTRUCTION)
        print(f"Result: {json.dumps(result, indent=2)}")
        print(f"Metadata: {json.dumps(metadata, indent=2)}")
        print("-" * 50)

    # Test case 2: Multiple APIs, follows the OpenAI format: https://platform.openai.com/docs/guides/function-calling
    calculator_api = {
        "name": "calculate",
        "description": "Perform a mathematical calculation",
        "parameters": {
            "type": "object",
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": ["add", "subtract", "multiply", "divide"],
                    "description": "The mathematical operation to perform"
                },
                "x": {
                    "type": "number",
                    "description": "The first number"
                },
                "y": {
                    "type": "number",
                    "description": "The second number"
                }
            },
            "required": ["operation", "x", "y"]
        }
    }

    multi_api_query = "What's the weather in Miami and what's 15 multiplied by 7?"
    multi_api_result, multi_api_metadata = handler.process_query(
        multi_api_query,
        [weather_api, calculator_api],
        TASK_INSTRUCTION,
        FORMAT_INSTRUCTION
    )

    print("Multi-API Query Result:")
    print(json.dumps(multi_api_result, indent=2))
    print(f"Metadata: {json.dumps(multi_api_metadata, indent=2)}")