---
base_model: meta-llama/Llama-3.2-1B-Instruct
datasets:
- argilla-warehouse/apigen-synth-trl
library_name: transformers
model_name: Llama-3.2-1B-Instruct-APIGen-FC-v0.1
tags:
- generated_from_trainer
- trl
- sft
license: apache-2.0
language:
- en
---

# Model Card for Llama-3.2-1B-Instruct-APIGen-FC-v0.1

This model is a fine-tuned version of [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) on the
[argilla-warehouse/apigen-synth-trl](https://huggingface.co/datasets/argilla-warehouse/apigen-synth-trl) dataset, a version of
[argilla/Synth-APIGen-v0.1](https://huggingface.co/datasets/argilla-warehouse/Synth-APIGen-v0.1) formatted for supervised fine-tuning (SFT).
It was trained using [TRL](https://github.com/huggingface/trl).

## Quick start

This model is a fine-tuned version of `Llama-3.2-1B-Instruct` specialized for function calling. It showcases how to fine-tune a model on a dataset
like [argilla/Synth-APIGen-v0.1](https://huggingface.co/datasets/argilla/Synth-APIGen-v0.1).

### Helper functions for the prompt and output parsing

<details><summary> Click to see helper functions </summary>

````python
import json
import re
from typing import Any, Optional

from jinja2 import Template

SYSTEM_PROMPT = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.

The output MUST strictly adhere to the following format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make the tool calls an empty list '[]'.
```
<tool_call>[
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]</tool_call>
```
""".strip()

prompt = Template("""
You have access to the following tools:
<tools>{{ tools }}</tools>

Please answer the following query:
{{ query }}
""".lstrip())

def prepare_messages(
    query: str,
    tools: Optional[list[dict[str, Any]]] = None,
    conversation_history: Optional[list[dict[str, str]]] = None
) -> list[dict[str, str]]:
    """Prepare the system and user messages for the given query and tools.

    Args:
        query: The query to be answered.
        tools: The tools available to the user. Defaults to None, in which case
            an empty list is passed to the model.
        conversation_history: Exchange of messages, including the system prompt from
            the first query. Defaults to None, meaning this is the first message
            in a conversation.
    """
    if tools is None:
        tools = []

    if conversation_history:
        # Follow-up turn: the tools were already rendered in the first user message.
        messages = conversation_history.copy()
        messages.append({"role": "user", "content": query})
    else:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt.render(tools=json.dumps(tools), query=query)}
        ]

    return messages


def parse_response(text: str) -> str | list[dict[str, Any]]:
    """Parses a response from the model, returning either the
    parsed list of tool calls, or the model's plain-text
    response if it did not generate any tool call.

    Args:
        text: Response from the model.
    """
    pattern = r"<tool_call>(.*?)</tool_call>"
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return json.loads(matches[0])
    return text

````

</details>

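
The `parse_response` helper above passes the captured text straight to `json.loads`, so a malformed payload between the `<tool_call>` tags would raise `json.JSONDecodeError`. The sketch below is one possible defensive variant, not part of the original helpers: the name `parse_response_safe` and the decision to fall back to the raw text are assumptions made for illustration.

```python
from __future__ import annotations

import json
import re
from typing import Any

TOOL_CALL_PATTERN = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


def parse_response_safe(text: str) -> str | list[dict[str, Any]]:
    """Hypothetical variant of `parse_response` that never raises on bad JSON."""
    match = TOOL_CALL_PATTERN.search(text)
    if match is None:
        # No tool-call block at all: treat the output as a plain-text reply.
        return text
    try:
        calls = json.loads(match.group(1))
    except json.JSONDecodeError:
        # Malformed payload: return the raw text so the caller can inspect it.
        return text
    # Only accept the shape requested in the system prompt: a list of calls.
    if not isinstance(calls, list):
        return text
    return calls


sample = '<tool_call>[{"name": "get_weather", "arguments": {"location": "New York"}}]</tool_call><|eot_id|>'
print(parse_response_safe(sample))
print(parse_response_safe("<tool_call>[{'name': broken]</tool_call>"))
```

With this variant, callers can check `isinstance(result, list)` to tell a usable list of tool calls apart from any text that needs manual handling.
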
### Examples

The following examples show how to use the model with `transformers` for different types of queries, with and without suitable tools available. They assume the helper functions above are already defined, and the collapsed examples reuse the `model` and `tokenizer` loaded in the first one.


Example of *simple* function call:

````python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "argilla-warehouse/Llama-3.2-1B-Instruct-APIGen-FC-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)


get_weather_api = {
    "name": "get_weather",
    "description": "Get the current weather for a location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, New York"
            },
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
                "description": "The unit of temperature to return"
            }
        },
        "required": ["location"]
    }
}

search_api = {
    "name": "search",
    "description": "Search for information on the internet",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query, e.g. 'latest news on AI'"
            }
        },
        "required": ["query"]
    }
}

available_tools = [get_weather_api, search_api]

query = "What's the weather like in New York in fahrenheit?"

messages = prepare_messages(query, tools=available_tools)

inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)

response = parse_response(result)
# [{'name': 'get_weather', 'arguments': {'location': 'New York', 'unit': 'fahrenheit'}}]
````

#### `Parallel` function call

<details><summary> Click here: </summary>

```python
available_tools = [{"name": "spotify.play", "description": "Play specific tracks from a given artist for a specific time duration.", "parameters": {"type": "dict", "properties": {"artist": {"type": "string", "description": "The artist whose songs you want to play."}, "duration": {"type": "integer", "description": "The duration for which the songs should be played, in minutes."}}, "required": ["artist", "duration"]}}]
query = "Play songs from the artists Taylor Swift and Maroon 5, with a play time of 20 minutes and 15 minutes respectively, on Spotify."

messages = prepare_messages(query, tools=available_tools)

inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)

response = parse_response(result)
# [{'name': 'spotify.play', 'arguments': {'artist': 'Taylor Swift', 'duration': 20}}, {'name': 'spotify.play', 'arguments': {'artist': 'Maroon 5', 'duration': 15}}]
```

</details>

#### `Multiple` function call

<details><summary> Click here: </summary>

```python
available_tools = [{"name": "country_info.largest_city", "description": "Fetch the largest city of a specified country.", "parameters": {"type": "dict", "properties": {"country": {"type": "string", "description": "Name of the country."}}, "required": ["country"]}}, {"name": "country_info.capital", "description": "Fetch the capital city of a specified country.", "parameters": {"type": "dict", "properties": {"country": {"type": "string", "description": "Name of the country."}}, "required": ["country"]}}, {"name": "country_info.population", "description": "Fetch the current population of a specified country.", "parameters": {"type": "dict", "properties": {"country": {"type": "string", "description": "Name of the country."}}, "required": ["country"]}}]
query = "What is the capital of Brazil?"

messages = prepare_messages(query, tools=available_tools)

inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)

response = parse_response(result)
# [{'name': 'country_info.capital', 'arguments': {'country': 'Brazil'}}]
```

</details>

#### `Parallel multiple` function call

<details><summary> Click here: </summary>

```python
available_tools = [{"name": "math_toolkit.sum_of_multiples", "description": "Find the sum of all multiples of specified numbers within a specified range.", "parameters": {"type": "dict", "properties": {"lower_limit": {"type": "integer", "description": "The start of the range (inclusive)."}, "upper_limit": {"type": "integer", "description": "The end of the range (inclusive)."}, "multiples": {"type": "array", "items": {"type": "integer"}, "description": "The numbers to find multiples of."}}, "required": ["lower_limit", "upper_limit", "multiples"]}}, {"name": "math_toolkit.product_of_primes", "description": "Find the product of the first n prime numbers.", "parameters": {"type": "dict", "properties": {"count": {"type": "integer", "description": "The number of prime numbers to multiply together."}}, "required": ["count"]}}]
query = "Find the sum of all the multiples of 3 and 5 between 1 and 1000. Also find the product of the first five prime numbers."

messages = prepare_messages(query, tools=available_tools)

inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)

response = parse_response(result)
# [{'name': 'math_toolkit.sum_of_multiples', 'arguments': {'lower_limit': 1, 'upper_limit': 1000, 'multiples': [3, 5]}}, {'name': 'math_toolkit.product_of_primes', 'arguments': {'count': 5}}]
```

</details>

#### `Multi-turn` function call

<details><summary> Click here: </summary>

```python

get_weather_api = {
    "name": "get_weather",
    "description": "Get the current weather for a location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, New York"
            },
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
                "description": "The unit of temperature to return"
            }
        },
        "required": ["location"]
    }
}

available_tools = [get_weather_api]

query = "What's the weather like in Madrid in celsius?"

messages = prepare_messages(query, tools=available_tools)

inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)

response = parse_response(result)

# 2nd turn
conversation_history = messages.copy()
conversation_history.append({"role": "assistant", "content": json.dumps(response)})

new_query = "And in Edinburgh in celsius?"

new_messages = prepare_messages(new_query, tools=available_tools, conversation_history=conversation_history)

inputs = tokenizer.apply_chat_template(new_messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=False)

response = parse_response(result)
# [{'name': 'get_weather', 'arguments': {'location': 'Edinburgh', 'unit': 'celsius'}}]
```

</details>

#### `Irrelevance` function call (examples when some data is missing)

<details><summary> Click here: </summary>

Example response when no tools are available:

```python
available_tools = []

query = "What's the weather like in New York in fahrenheit?"

messages = prepare_messages(query, tools=available_tools)

inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
response = parse_response(result)
# 'The query cannot be answered, no tools were provided.'
```

Example when the only tool provided is irrelevant to the query:

```python
cut_number = {
    'type': 'function',
    'function': {
        'name': 'cut_number',
        'description': 'Returns the value `number` if it is greater than or equal to `threshold`, otherwise returns the value `threshold`.',
        'parameters': {
            'type': 'object',
            'properties': {'number': {'type': 'number', 'description': 'The number to compare.'}},
            'required': ['number']
        }
    }
}

available_tools = [cut_number]

query = "What's the weather like in New York in fahrenheit?"

messages = prepare_messages(query, tools=available_tools)
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
result = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
response = parse_response(result)
# "The query cannot be answered with the provided tools. The query lacks the parameters required by the function. Please provide the parameters, and I'll be happy to assist."
```

</details>

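
Across the examples above, the parsed response is either a list of tool calls or a plain-text string when the model declines. A minimal sketch of how an application might act on both cases follows. The registry, the `run_tool_calls` function, and the stand-in `get_weather` implementation are hypothetical names introduced for illustration; they are not part of this model card or of any library.

```python
from __future__ import annotations

from typing import Any, Callable


def get_weather(location: str, unit: str = "celsius") -> str:
    """Stand-in implementation; a real one would query a weather service."""
    return f"Weather for {location} in {unit}"


# Hypothetical registry mapping tool names to local Python callables.
TOOL_REGISTRY: dict[str, Callable[..., Any]] = {"get_weather": get_weather}


def run_tool_calls(response: str | list[dict[str, Any]]) -> list[Any] | str:
    """Execute parsed tool calls, or pass a plain-text reply through unchanged."""
    if isinstance(response, str):
        # The model refused or answered in prose: nothing to execute.
        return response
    results = []
    for call in response:
        func = TOOL_REGISTRY.get(call.get("name"))
        if func is None:
            # The model named a tool that is not registered locally.
            results.append(f"Unknown tool: {call.get('name')}")
            continue
        results.append(func(**call.get("arguments", {})))
    return results


print(run_tool_calls([{"name": "get_weather", "arguments": {"location": "New York", "unit": "fahrenheit"}}]))
print(run_tool_calls("The query cannot be answered, no tools were provided."))
```

Keeping the dispatch in an explicit registry means only functions you have listed can ever be invoked, which is a reasonable precaution when executing arguments produced by a model.
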
## Training procedure

[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/plaguss/huggingface/runs/dw9q43g4)

This model was trained with SFT. You can take a look at [sft.slurm](https://huggingface.co/argilla/Llama-3.2-1B-Instruct-APIGen-FC-v0.1/blob/main/sft.slurm) to see the
training script. If you don't have access to a Slurm cluster, it can be run using just the `accelerate` command. Training took 13 minutes on a node with 8xH100.

To install the requirements, the following commands can be used:

```bash
uv venv .venv --python 3.11
source .venv/bin/activate
git clone https://github.com/huggingface/trl.git
# Install TRL from the cloned repository
cd trl
uv pip install .
uv pip install wandb
uv pip install deepspeed
```

Then log in to your WandB and Hugging Face accounts to push both the logs and the final model.

### Framework versions

- TRL: 0.12.0.dev0
- Transformers: 4.45.1
- PyTorch: 2.4.1
- Datasets: 3.0.1
- Tokenizers: 0.20.0

## Citations

Cite TRL as:

```bibtex
@misc{vonwerra2022trl,
    title        = {{TRL: Transformer Reinforcement Learning}},
    author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
    year         = 2020,
    journal      = {GitHub repository},
    publisher    = {GitHub},
    howpublished = {\url{https://github.com/huggingface/trl}}
}
```