239 lines
7.2 KiB
Rust
239 lines
7.2 KiB
Rust
|
|
use super::traits::Tokenizer as TokenizerTrait;
|
|||
|
|
use anyhow::Result;
|
|||
|
|
use std::sync::Arc;
|
|||
|
|
|
|||
|
|
/// Maintains state for an ongoing sequence of tokens and their decoded text
///
/// This provides a cleaner abstraction for managing token sequences
/// during incremental (streaming) detokenization.
pub struct Sequence {
    /// The tokenizer used for encoding/decoding
    tokenizer: Arc<dyn TokenizerTrait>,

    /// The current sequence of token ids
    token_ids: Vec<u32>,

    /// The position in the current sequence the last decoded token completed
    /// (start of the decode window used by `append_token`)
    prefix_offset: usize,

    /// Current position in the sequence
    /// (index just past the last token already seen by `append_token`)
    read_offset: usize,
}
|
|||
|
|
|
|||
|
|
impl std::fmt::Debug for Sequence {
|
|||
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|||
|
|
f.debug_struct("Sequence")
|
|||
|
|
.field("tokenizer", &"Arc<dyn Tokenizer>")
|
|||
|
|
.field(
|
|||
|
|
"token_ids",
|
|||
|
|
&format_args!("{}", {
|
|||
|
|
let token_ids = self.token_ids();
|
|||
|
|
if token_ids.len() <= 20 {
|
|||
|
|
format!("{:?}", token_ids)
|
|||
|
|
} else {
|
|||
|
|
let first_ten = &token_ids[..10];
|
|||
|
|
let last_ten = &token_ids[token_ids.len() - 10..];
|
|||
|
|
format!("{:?} ... {:?}", first_ten, last_ten)
|
|||
|
|
}
|
|||
|
|
}),
|
|||
|
|
)
|
|||
|
|
.field("prefix_offset", &self.prefix_offset)
|
|||
|
|
.field("read_offset", &self.read_offset)
|
|||
|
|
.field("token count", &self.token_ids.len())
|
|||
|
|
.finish()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
impl Sequence {
|
|||
|
|
/// Create a new empty sequence
|
|||
|
|
pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
|
|||
|
|
Self {
|
|||
|
|
tokenizer,
|
|||
|
|
token_ids: Vec::new(),
|
|||
|
|
prefix_offset: 0,
|
|||
|
|
read_offset: 0,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Create a sequence with initial tokens
|
|||
|
|
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<u32>) -> Self {
|
|||
|
|
let len = token_ids.len();
|
|||
|
|
Self {
|
|||
|
|
tokenizer,
|
|||
|
|
token_ids,
|
|||
|
|
prefix_offset: 0,
|
|||
|
|
read_offset: len,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Check if the sequence is empty
|
|||
|
|
pub fn is_empty(&self) -> bool {
|
|||
|
|
self.token_ids.is_empty()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Get the length of the sequence
|
|||
|
|
pub fn len(&self) -> usize {
|
|||
|
|
self.token_ids.len()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Clear the sequence
|
|||
|
|
pub fn clear(&mut self) {
|
|||
|
|
self.token_ids.clear();
|
|||
|
|
self.prefix_offset = 0;
|
|||
|
|
self.read_offset = 0;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Append text to the sequence by encoding it
|
|||
|
|
pub fn append_text(&mut self, input: &str) -> Result<()> {
|
|||
|
|
let encoding = self.tokenizer.encode(input)?;
|
|||
|
|
self.token_ids.extend(encoding.token_ids());
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Append a single token to the sequence and return newly decoded text
|
|||
|
|
/// Based on HuggingFace TGI incremental decoding
|
|||
|
|
pub fn append_token(&mut self, token_id: u32) -> Result<String> {
|
|||
|
|
// Store the old read offset before adding the new token
|
|||
|
|
let old_read_offset = self.read_offset;
|
|||
|
|
|
|||
|
|
self.token_ids.push(token_id);
|
|||
|
|
self.read_offset = self.token_ids.len();
|
|||
|
|
|
|||
|
|
// If this is the first token or we're at the beginning, decode everything
|
|||
|
|
if self.prefix_offset == 0 && old_read_offset == 0 {
|
|||
|
|
let text = self.tokenizer.decode(&self.token_ids, false)?;
|
|||
|
|
if text.ends_with("<EFBFBD>") {
|
|||
|
|
// Incomplete UTF-8 sequence, wait for more tokens
|
|||
|
|
return Ok(String::new());
|
|||
|
|
}
|
|||
|
|
self.prefix_offset = 0;
|
|||
|
|
return Ok(text);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Decode the text up to the previous position
|
|||
|
|
let prefix_text = self
|
|||
|
|
.tokenizer
|
|||
|
|
.decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?;
|
|||
|
|
|
|||
|
|
// Decode the text including the new token
|
|||
|
|
let new_text = self
|
|||
|
|
.tokenizer
|
|||
|
|
.decode(&self.token_ids[self.prefix_offset..], false)?;
|
|||
|
|
|
|||
|
|
// Handle multi-byte character boundaries
|
|||
|
|
let mut prefix_text_len = prefix_text.len();
|
|||
|
|
while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
|
|||
|
|
prefix_text_len -= 1;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if new_text.len() > prefix_text.len() {
|
|||
|
|
if new_text.ends_with("<EFBFBD>") {
|
|||
|
|
// Incomplete UTF-8 sequence, wait for more tokens
|
|||
|
|
return Ok(String::new());
|
|||
|
|
} else {
|
|||
|
|
// Return the new text portion
|
|||
|
|
let incremental_text = new_text[prefix_text_len..].to_string().replace("<EFBFBD>", "");
|
|||
|
|
self.prefix_offset = old_read_offset;
|
|||
|
|
return Ok(incremental_text);
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(String::new())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Get a reference to the tokenizer
|
|||
|
|
pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
|
|||
|
|
&self.tokenizer
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Get the current token ids
|
|||
|
|
pub fn token_ids(&self) -> &[u32] {
|
|||
|
|
&self.token_ids
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Decode the entire sequence to text
|
|||
|
|
pub fn text(&self) -> Result<String> {
|
|||
|
|
self.tokenizer.decode(&self.token_ids, false)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Get the prefix offset
|
|||
|
|
pub fn prefix_offset(&self) -> usize {
|
|||
|
|
self.prefix_offset
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Get the read offset
|
|||
|
|
pub fn read_offset(&self) -> usize {
|
|||
|
|
self.read_offset
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    /// A fresh sequence is empty with zero length.
    #[test]
    fn test_sequence_new() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let seq = Sequence::new(tokenizer);
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
    }

    /// Encoding text populates the sequence and round-trips through decode.
    #[test]
    fn test_sequence_append_text() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Hello").unwrap();
        // (The original had this assertion duplicated on two consecutive
        // lines — copy-paste leftover, removed.)
        assert!(!seq.is_empty());

        let text = seq.text().unwrap();
        assert_eq!(text, "Hello");
    }

    /// Incremental decoding returns only the newly produced text per token.
    #[test]
    fn test_sequence_append_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer.clone());

        // Start with an empty sequence and append token 1 ("Hello")
        let text1 = seq.append_token(1).unwrap();
        assert_eq!(text1, "Hello");

        // Now append token 2 ("world")
        // The mock tokenizer will decode [1, 2] as "Hello world" (with a space)
        let text2 = seq.append_token(2).unwrap();
        // The incremental text should be " world" (with the space that the mock tokenizer adds)
        assert_eq!(text2, " world");

        // Verify the full text
        assert_eq!(seq.text().unwrap(), "Hello world");
    }

    /// `clear` empties the tokens and resets both offsets.
    #[test]
    fn test_sequence_clear() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Hello world").unwrap();
        assert!(!seq.is_empty());

        seq.clear();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
        assert_eq!(seq.prefix_offset(), 0);
        assert_eq!(seq.read_offset(), 0);
    }

    /// The custom Debug impl renders the struct name and token count field.
    #[test]
    fn test_sequence_debug() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);

        seq.append_text("Test").unwrap();
        let debug_str = format!("{:?}", seq);
        assert!(debug_str.contains("Sequence"));
        assert!(debug_str.contains("token count"));
    }
}
|