130 lines
4.1 KiB
Python
130 lines
4.1 KiB
Python
|
|
# Generic parser for OpenAI Harmony format.
|
||
|
|
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from enum import Enum
|
||
|
|
from typing import Iterator, Sequence
|
||
|
|
|
||
|
|
from transformers import PreTrainedTokenizerBase as TokenizerLike
|
||
|
|
|
||
|
|
|
||
|
|
class HarmonyMessageEndType(Enum):
    """How a parsed Harmony message was terminated."""

    # The token stream ran out before a terminator token was seen.
    INCOMPLETE = 0
    # The message was closed by <|end|> or <|return|>.
    END = 1
    # The message was closed by <|call|>.
    CALL = 2
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class HarmonySequence:
    """A run of token ids that makes up one section of a Harmony message.

    Attributes:
        token_ids: The token ids of the section.
        start: Start position of the sequence in the original token sequence.
    """

    token_ids: list[int]
    start: int
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class HarmonyMessage:
    """One message extracted from a Harmony token stream.

    Attributes:
        end: How the message was terminated.
        role: Tokens of the role section, or None if the section is absent.
        channel: Tokens of the channel section, or None if the section is absent.
        constrain: Tokens of the constrain section, or None if the section is absent.
        content: Tokens of the content section, or None if the section is absent.
    """

    end: HarmonyMessageEndType
    role: HarmonySequence | None = None
    channel: HarmonySequence | None = None
    constrain: HarmonySequence | None = None
    content: HarmonySequence | None = None
|
||
|
|
|
||
|
|
|
||
|
|
class HarmonyMessageParser:
    """A parser that performs lexical analysis to extract Harmony messages.

    The token stream is split on Harmony special tokens: ``<|start|>``,
    ``<|channel|>``, ``<|constrain|>`` and ``<|message|>`` open the ``role``,
    ``channel``, ``constrain`` and ``content`` sections respectively, while
    ``<|end|>``, ``<|return|>`` and ``<|call|>`` terminate the current message.
    Ordinary tokens seen while no section is open are ignored.
    """

    def __init__(self, tokenizer: TokenizerLike):
        """Look up the ids of the Harmony special tokens.

        Args:
            tokenizer: Tokenizer whose ``get_vocab()`` mapping contains the
                Harmony special tokens.

        Raises:
            KeyError: If a required special token is missing from the vocabulary.
        """
        vocab = tokenizer.get_vocab()
        start_id = vocab["<|start|>"]
        # reverse_iter_messages searches for this id to find message boundaries.
        self._start_id = start_id
        # Token id -> name of the HarmonyMessage field that the token opens.
        self._begin_map = {
            start_id: "role",
            vocab["<|channel|>"]: "channel",
            vocab["<|constrain|>"]: "constrain",
            vocab["<|message|>"]: "content",
        }
        # Token id -> how a message closed by that token is classified.
        self._end_map = {
            vocab["<|end|>"]: HarmonyMessageEndType.END,
            vocab["<|return|>"]: HarmonyMessageEndType.END,
            vocab["<|call|>"]: HarmonyMessageEndType.CALL,
        }

    @staticmethod
    def _store_section(
        fields: dict[str, HarmonySequence],
        section: str | None,
        text_ids: list[int],
        text_start: int | None,
    ) -> None:
        """Record the currently open section, if any, into ``fields``."""
        if section is not None:
            fields[section] = HarmonySequence(token_ids=text_ids, start=text_start)

    def _iter_messages_from(
        self, token_ids: Sequence[int], offset: int
    ) -> Iterator[HarmonyMessage]:
        """Scan ``token_ids`` and yield messages with positions shifted by ``offset``."""
        fields: dict[str, HarmonySequence] = {}
        section: str | None = None  # None indicates out-of-message.
        text_ids: list[int] = []
        text_start: int | None = None

        for position, token_id in enumerate(token_ids, start=offset):
            if token_id in self._begin_map:
                # A new section begins; close the one that was open, if any.
                self._store_section(fields, section, text_ids, text_start)
                section = self._begin_map[token_id]
                text_ids = []
                # The section text starts right after its opening token.
                text_start = position + 1
            elif token_id in self._end_map:
                self._store_section(fields, section, text_ids, text_start)
                yield HarmonyMessage(**fields, end=self._end_map[token_id])
                # Reset all state for the next message.
                fields = {}
                section = None
                text_ids = []
                text_start = None
            elif section is not None:
                text_ids.append(token_id)

        # The input ended inside a message: report what was collected so far.
        if section is not None:
            self._store_section(fields, section, text_ids, text_start)
            yield HarmonyMessage(**fields, end=HarmonyMessageEndType.INCOMPLETE)

    def iter_messages(self, token_ids: Sequence[int]) -> Iterator[HarmonyMessage]:
        """
        Parse given token ids into messages.

        A message is yielded each time a terminator token is met. If the input
        ends while a section is still open, a final message with
        ``HarmonyMessageEndType.INCOMPLETE`` is yielded. When a section appears
        more than once in a message, the last occurrence wins.

        Args:
            token_ids: A sequence of token ids to be parsed.

        Yields:
            Detected HarmonyMessages.
        """
        yield from self._iter_messages_from(token_ids, 0)

    def get_all_messages(self, token_ids: Sequence[int]) -> list[HarmonyMessage]:
        """
        Parse given token ids into messages.

        Args:
            token_ids: A sequence of token ids to be parsed.

        Returns:
            A list of detected HarmonyMessages.
        """
        return list(self.iter_messages(token_ids))

    def reverse_iter_messages(self, token_ids: Sequence[int]) -> Iterator[HarmonyMessage]:
        """
        Parse given token ids into messages in reverse order.

        The input is scanned backwards for ``<|start|>`` tokens, and the first
        message that begins at each of them is yielded. Tokens located before
        the first ``<|start|>`` are therefore not reported. Section start
        positions refer to ``token_ids`` itself, as in ``iter_messages``.

        Args:
            token_ids: A sequence of token ids to be parsed.

        Yields:
            Detected HarmonyMessages in reverse order.
        """
        end_position = len(token_ids)

        for i in range(len(token_ids) - 1, -1, -1):
            if token_ids[i] == self._start_id:
                # The slice opens with <|start|>, so at least one message
                # (possibly INCOMPLETE) is always produced and next() cannot
                # run dry. Passing the offset keeps positions absolute.
                yield next(self._iter_messages_from(token_ids[i:end_position], i))
                end_position = i
|