# llm-jp-4 tokenizer

from collections.abc import Sequence
import os

from transformers import LlamaTokenizerFast
from tokenizers import Tokenizer

from .llmjp4_harmony import HarmonyMessageParser, HarmonyMessage


class Llmjp4Tokenizer(LlamaTokenizerFast):
    """Tokenizer for llm-jp-4 that handles Harmony chat-format special tokens."""

    # Special tokens of the Harmony chat format.
    _HARMONY_TOKENS: set[str] = {
        "<|start|>",
        "<|message|>",
        "<|channel|>",
        "<|constrain|>",
        "<|end|>",
        "<|return|>",
        "<|call|>",
    }

    # NOTE(odashi):
    # Response schemas are not recognized automatically.
    # We need to define them manually.
    # https://github.com/huggingface/trl/issues/4609
    _RESPONSE_SCHEMA = {
        "type": "object",
        "properties": {
            "role": {"const": "assistant"},
            "content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|<\|return\|>|$)"},
            "thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"},
            "tool_calls": {
                "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "type": {"const": "function"},
                        "function": {
                            "type": "object",
                            "properties": {
                                "name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
                                "arguments": {
                                    "type": "object",
                                    "x-regex": r"<\|message\|>(.*)",
                                    "x-parser": "json",
                                    "additionalProperties": {"type": "any"},
                                },
                            },
                        },
                    },
                },
            },
        },
    }
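    # For reference, the regexes above target Harmony-style assistant output, which
    # looks roughly like the following (illustrative only; the function name and
    # arguments are placeholders):
    #   <|channel|>analysis<|message|>...reasoning...<|end|>
    #   <|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"city": "Tokyo"}<|call|>
    #   <|channel|>final<|message|>...final answer...<|return|>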

    @classmethod
    def convert_to_native_format(cls, **kwargs):
        """Rewrite loading kwargs: replace `tokenizer_file` with a loaded `tokenizer_object`."""
        # NOTE(odashi):
        # Workaround for transformers 5.x.
        # Guarantees the same inner behavior as TokenizersBackend.
        # https://github.com/huggingface/transformers/blob/7d9754a05193eb79b1d86aa744b622b8068008cd/src/transformers/tokenization_utils_tokenizers.py#L110-L116
        local_kwargs = dict(kwargs)
        fast_tokenizer_file = local_kwargs.pop("tokenizer_file", None)
        if fast_tokenizer_file is None or not os.path.isfile(fast_tokenizer_file):
            raise ValueError("Tokenizer file must exist.")

        local_kwargs["tokenizer_object"] = Tokenizer.from_file(fast_tokenizer_file)
        return local_kwargs
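        # Illustrative call (assumes a local tokenizer.json exists; values are placeholders):
        #   convert_to_native_format(tokenizer_file="tokenizer.json", unk_token="<unk>")
        #   -> {"unk_token": "<unk>", "tokenizer_object": <tokenizers.Tokenizer>}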

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.response_schema = self._RESPONSE_SCHEMA

        # Cache the IDs of Harmony special tokens for segment-wise decoding in `_decode`.
        self._harmony_token_ids = {
            self.convert_tokens_to_ids(token)
            for token in self._HARMONY_TOKENS
        }

    def _decode(self, token_ids: int | list[int], *args, **kwargs):
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        result: list[str] = []
        prev_pos = 0

        # NOTE(odashi):
        # Ensure that text tokens are decoded without preceding Harmony tokens
        # to avoid incorrect addition of whitespace.
        for pos, token_id in enumerate(token_ids, start=1):
            if token_id in self._harmony_token_ids or pos == len(token_ids):
                result.append(super()._decode(token_ids[prev_pos:pos], *args, **kwargs))
                prev_pos = pos

        return "".join(result)

    def parse_harmony_message(self, token_ids: Sequence[int]) -> list[HarmonyMessage]:
        """Helper function to parse token IDs into Harmony messages."""
        return HarmonyMessageParser(self).get_all_messages(token_ids)
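
# A minimal usage sketch (illustrative; the checkpoint path and input text are
# placeholders and assume the saved tokenizer config resolves to this class):
#
#     tokenizer = Llmjp4Tokenizer.from_pretrained("path/to/llm-jp-4-tokenizer")
#     ids = tokenizer("<|start|>assistant<|channel|>final<|message|>Hello!<|return|>")["input_ids"]
#     text = tokenizer.decode(ids)
#     messages = tokenizer.parse_harmony_message(ids)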