---
license: apache-2.0
language:
- en
library_name: transformers
tags:
- turn-taking
- voice-ai
- conversational-ai
- dialogue
- qwen2
- onnx
base_model: Qwen/Qwen2.5-0.5B-Instruct
pipeline_tag: text-generation
---

# Semantic Turn-Taking Model

A fine-tuned [Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) model that predicts turn-taking actions in conversations. Given the conversation so far, the model predicts which action a voice AI agent should take next.

Unlike acoustic approaches such as voice activity detection (VAD) or silence thresholds, this model bases its turn-taking decisions on the **semantic content** of the conversation.

## Action Classes

The model predicts one of four actions:

| Action | Token | Description |
|--------|-------|-------------|
| `start_speaking` | `<\|start_speaking\|>` | User finished their turn; agent should respond |
| `continue_listening` | `<\|continue_listening\|>` | User is mid-utterance; keep listening |
| `start_listening` | `<\|start_listening\|>` | User interrupted the agent; stop talking |
| `continue_speaking` | `<\|continue_speaking\|>` | User gave a backchannel; agent keeps talking |

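The descriptions in the table suggest that `start_speaking` and `continue_listening` are decided while the user is talking and the agent is listening, whereas `start_listening` and `continue_speaking` are decided while the agent is talking. The sketch that follows is one hypothetical way an agent loop could use that split: it restricts the choice to the actions that make sense in the agent's current state and picks the most probable one. `AgentState`, `VALID_ACTIONS`, `NEXT_STATE` and `choose_action` are illustrative names, and the state split is an inference from the table rather than something this card specifies.

```python
from enum import Enum


class AgentState(Enum):
    LISTENING = "listening"  # the user holds the floor
    SPEAKING = "speaking"    # the agent is producing speech


# Assumption: which actions are meaningful in which state, inferred from the
# action descriptions in the table (not stated explicitly by the model card).
VALID_ACTIONS = {
    AgentState.LISTENING: ("start_speaking", "continue_listening"),
    AgentState.SPEAKING: ("start_listening", "continue_speaking"),
}

# Hypothetical state the agent moves to after each action.
NEXT_STATE = {
    "start_speaking": AgentState.SPEAKING,
    "continue_listening": AgentState.LISTENING,
    "start_listening": AgentState.LISTENING,
    "continue_speaking": AgentState.SPEAKING,
}


def choose_action(state: AgentState, probs: dict[str, float]) -> tuple[str, AgentState]:
    """Pick the most probable action among those valid in the current state."""
    candidates = VALID_ACTIONS[state]
    action = max(candidates, key=lambda name: probs[name])
    return action, NEXT_STATE[action]


# Example: the agent is talking and the model mostly predicts a backchannel.
probs = {"start_speaking": 0.10, "continue_listening": 0.05,
         "start_listening": 0.15, "continue_speaking": 0.70}
action, next_state = choose_action(AgentState.SPEAKING, probs)
print(action, next_state.value)  # → continue_speaking speaking
```

Restricting to valid actions before taking the argmax keeps, for example, a stray `start_listening` prediction from being acted on while the agent is already silent.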
## Usage

### PyTorch

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "anyreach-ai/semantic-turn-taking"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # fp16 on GPU, fp32 on CPU

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
model.eval()

# Format the conversation as ChatML, ending with the <|predict|> trigger token
conversation = """<|im_start|>user
I need help with my bill<|im_end|>
<|im_start|>assistant
Sure I can help with that what seems to be the issue<|im_end|>
<|im_start|>user
I was charged twice for the same order<|im_end|>
<|predict|>"""

inputs = tokenizer(conversation, return_tensors="pt").to(device)

with torch.no_grad():
    logits = model(**inputs).logits[:, -1, :]  # next-token logits at the last position

# Look up the IDs of the four action tokens
action_tokens = {
    "start_speaking": tokenizer.convert_tokens_to_ids("<|start_speaking|>"),
    "continue_listening": tokenizer.convert_tokens_to_ids("<|continue_listening|>"),
    "start_listening": tokenizer.convert_tokens_to_ids("<|start_listening|>"),
    "continue_speaking": tokenizer.convert_tokens_to_ids("<|continue_speaking|>"),
}

# Softmax over the four action logits only (cast to fp32 for numerical stability)
probs = torch.softmax(logits[0, list(action_tokens.values())].float(), dim=-1)
for name, p in zip(action_tokens, probs.tolist()):
    print(f" {name}: {p:.4f}")
# → start_speaking: 0.95+ (user is done, agent should respond)
```

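Writing the ChatML string by hand is error-prone. A minimal sketch of a helper that builds the same prompt layout from a list of `(role, text)` turns, assuming the layout shown in the example above: one `<|im_start|>role\ntext<|im_end|>` block per turn, newline-separated, with `<|predict|>` as the final line. `build_prompt` is a hypothetical name. It avoids `tokenizer.apply_chat_template` because the base Qwen2.5 chat template inserts a default system prompt when none is given, which the example prompt does not contain; whether this model's tokenizer keeps that template is not stated here.

```python
def build_prompt(turns: list[tuple[str, str]]) -> str:
    """Format (role, text) turns as ChatML followed by the <|predict|> trigger.

    Roles are assumed to be "user" or "assistant", as in the example prompt.
    """
    parts = []
    for role, text in turns:
        if role not in ("user", "assistant"):
            raise ValueError(f"unexpected role: {role!r}")
        parts.append(f"<|im_start|>{role}\n{text}<|im_end|>\n")
    return "".join(parts) + "<|predict|>"


# Rebuilds exactly the `conversation` string used in the PyTorch example.
conversation = build_prompt([
    ("user", "I need help with my bill"),
    ("assistant", "Sure I can help with that what seems to be the issue"),
    ("user", "I was charged twice for the same order"),
])
print(conversation)
```

The returned string can be passed to `tokenizer(...)` in place of the hand-written `conversation` in either the PyTorch or the ONNX example.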
### ONNX (CPU)

```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

repo_id = "anyreach-ai/semantic-turn-taking"
tokenizer = AutoTokenizer.from_pretrained(repo_id)

sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = 4
session = ort.InferenceSession(
    hf_hub_download(repo_id, "onnx/model_q8.onnx"),  # fetch the INT8 model from this repo
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)

# Tokenize
conversation = "..."  # ChatML format as above
inputs = tokenizer(conversation, return_tensors="np")
input_ids = inputs["input_ids"].astype(np.int64)
seq_len = input_ids.shape[1]

# Build the feed. An empty KV cache (past length 0) gives a single full forward pass.
# Qwen2.5-0.5B has 24 layers, 2 key/value heads and a head dimension of 64.
feed = {
    "input_ids": input_ids,
    "attention_mask": inputs["attention_mask"].astype(np.int64),
    "position_ids": np.arange(seq_len, dtype=np.int64).reshape(1, -1),
}
for i in range(24):
    feed[f"past_key_values.{i}.key"] = np.zeros((1, 2, 0, 64), dtype=np.float32)
    feed[f"past_key_values.{i}.value"] = np.zeros((1, 2, 0, 64), dtype=np.float32)

# Run inference
logits = session.run(None, feed)[0]  # [1, seq_len, vocab_size]
last_logits = logits[0, -1, :]

# Extract action probabilities (IDs resolve to [151666, 151665, 151667, 151668])
ACTIONS = ["start_speaking", "continue_listening", "start_listening", "continue_speaking"]
action_ids = [tokenizer.convert_tokens_to_ids(f"<|{name}|>") for name in ACTIONS]
action_logits = last_logits[action_ids].astype(np.float64)
exp = np.exp(action_logits - action_logits.max())  # subtract the max for numerical stability
probs = exp / exp.sum()
for name, p in zip(ACTIONS, probs):
    print(f" {name}: {p:.4f}")
```

## Benchmark Results

Evaluated on [anyreach-ai/semantic-turn-taking-benchmark](https://huggingface.co/datasets/anyreach-ai/semantic-turn-taking-benchmark).

### Binary (EOU vs Not-EOU)

End-of-utterance (EOU) detection, restricted to examples labeled `start_speaking` or `continue_listening`. The model's four-way predictions are collapsed to two classes: `start_speaking`/`continue_speaking` → EOU, `continue_listening`/`start_listening` → Not-EOU.

| Subset | N | Accuracy | F1 (macro) |
|--------|--:|--:|--:|
| TEN | 428 | 91.82% | 91.80% |
| SwDA | 2,688 | 65.96% | 51.46% |
| Synthetic | 36 | 86.11% | 85.57% |

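A minimal sketch of how the binary scores can be computed from four-way predictions, assuming accuracy and macro F1 carry their usual definitions (macro F1 is the unweighted mean of the per-class F1 scores). `TO_BINARY`, `f1` and `binary_scores` are hypothetical names, and the benchmark's own evaluation script may differ in details such as how empty classes are handled.

```python
# Collapse the four actions into the binary EOU / Not-EOU labels.
TO_BINARY = {
    "start_speaking": "EOU",
    "continue_speaking": "EOU",
    "continue_listening": "NOT_EOU",
    "start_listening": "NOT_EOU",
}


def f1(tp: int, fp: int, fn: int) -> float:
    """F1 for one class, 2TP / (2TP + FP + FN); 0.0 if the class never appears."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def binary_scores(labels: list[str], preds: list[str]) -> tuple[float, float]:
    """Return (accuracy, macro F1) after mapping actions to EOU / Not-EOU."""
    y_true = [TO_BINARY[a] for a in labels]
    y_pred = [TO_BINARY[a] for a in preds]
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    per_class = []
    for cls in ("EOU", "NOT_EOU"):
        tp = sum(t == cls and p == cls for t, p in pairs)
        fp = sum(t != cls and p == cls for t, p in pairs)
        fn = sum(t == cls and p != cls for t, p in pairs)
        per_class.append(f1(tp, fp, fn))
    return accuracy, sum(per_class) / len(per_class)


# A continue_speaking prediction on a start_speaking label counts as correct (both are EOU).
labels = ["start_speaking", "start_speaking", "continue_listening", "continue_listening"]
preds = ["continue_speaking", "continue_listening", "continue_listening", "start_listening"]
acc, macro_f1 = binary_scores(labels, preds)
print(round(acc, 4), round(macro_f1, 4))  # → 0.75 0.7333
```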
### Multi-class

`Classes` is the number of distinct action labels in each subset.

| Subset | N | Classes | Accuracy | F1 (macro) |
|--------|--:|--------:|--:|--:|
| TEN | 428 | 2 | 91.82% | 91.80% |
| SwDA | 3,523 | 3 | 68.98% | 46.92% |
| Synthetic | 60 | 4 | 76.67% | 72.07% |

## Latency

Per-example latency (batch size 1) on CPU (4 threads) and on GPU (NVIDIA T4), for three input lengths.

| Format | Size | Short (8 tokens) | Medium (28 tokens) | Long (54 tokens) |
|--------|-----:|--:|--:|--:|
| PyTorch GPU (fp16) | 942 MB | 26 ms | 30 ms | 34 ms |
| PyTorch CPU (fp32) | 942 MB | 165 ms | 247 ms | 289 ms |
| ONNX CPU (q8) | 473 MB | 128 ms | 151 ms | 191 ms |

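The card does not describe the timing procedure (warm-up runs, number of repetitions, which statistic is reported). A minimal sketch of one common way to measure single-example latency, assuming a zero-argument callable that runs one forward pass, for example a lambda wrapping `session.run(None, feed)` from the ONNX example. `measure_latency_ms` and its defaults are hypothetical. For GPU timings the callable should call `torch.cuda.synchronize()` before returning, because CUDA kernels run asynchronously and would otherwise be under-counted.

```python
import math
import statistics
import time
from typing import Callable


def measure_latency_ms(run_once: Callable[[], object], warmup: int = 5, iters: int = 50) -> dict[str, float]:
    """Time `run_once` repeatedly; return median and nearest-rank p95 latency in ms."""
    if iters < 1:
        raise ValueError("iters must be at least 1")
    for _ in range(warmup):  # warm-up runs are excluded from the statistics
        run_once()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run_once()
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    p95 = samples[max(0, math.ceil(0.95 * len(samples)) - 1)]
    return {"median_ms": statistics.median(samples), "p95_ms": p95}


# Stand-in workload; replace the lambda with a real single-example forward pass.
print(measure_latency_ms(lambda: sum(range(10_000)), warmup=2, iters=20))
```

Reporting the median rather than the mean keeps occasional scheduler or garbage-collection spikes from dominating a small number of runs.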
## Model Details

- **Base model**: [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) (494M parameters)
- **Training**: Full fine-tuning on ~154K synthetic conversation examples
- **Input format**: Qwen ChatML with a `<|predict|>` trigger token
- **Max sequence length**: 1024 tokens (left truncation)
- **Special tokens**: 5 added (`<|predict|>` plus the 4 action tokens)

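Left truncation keeps the most recent tokens, so the `<|predict|>` trigger at the end of the prompt and the latest turns survive while the oldest context is dropped. A minimal sketch on plain token-ID lists, assuming inference inputs longer than 1,024 tokens should be cut the same way as in training. With a Hugging Face tokenizer, setting `tokenizer.truncation_side = "left"` and calling it with `truncation=True, max_length=1024` should have a similar effect. `left_truncate` is a hypothetical helper name.

```python
MAX_LEN = 1024  # maximum sequence length listed in Model Details


def left_truncate(input_ids: list[int], attention_mask: list[int],
                  max_len: int = MAX_LEN) -> tuple[list[int], list[int]]:
    """Keep only the last `max_len` tokens, dropping the oldest context."""
    if len(input_ids) != len(attention_mask):
        raise ValueError("input_ids and attention_mask must have the same length")
    if len(input_ids) <= max_len:
        return input_ids, attention_mask
    return input_ids[-max_len:], attention_mask[-max_len:]


# 1,030 token IDs are cut to the final 1,024, so the last token (the trigger) survives.
ids = list(range(1030))
mask = [1] * len(ids)
ids, mask = left_truncate(ids, mask)
print(len(ids), ids[0], ids[-1])  # → 1024 6 1029
```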
## Files

| File | Description |
|------|-------------|
| `model.safetensors` | PyTorch model weights (fp32) |
| `onnx/model_q8.onnx` | ONNX INT8 quantized model (dynamic quantization) |
| `config.json` | Model configuration |
| `tokenizer.json` | Tokenizer |

## Citation

```bibtex
@misc{semantic-turn-taking-2026,
  title={Semantic Turn-Taking Model},
  author={Shangeth Rajaa},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/anyreach-ai/semantic-turn-taking}
}
```

## Authors

- [**Shangeth Rajaa**](https://github.com/shangeth)

## License

Apache 2.0