初始化项目,由ModelHub XC社区提供模型
Model: llm-jp/llm-jp-4-8b-base Source: Original Platform
This commit is contained in:
129
llmjp4_harmony.py
Normal file
129
llmjp4_harmony.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Generic parser for OpenAI Harmony format.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Iterator, Sequence
|
||||
|
||||
from transformers import PreTrainedTokenizerBase as TokenizerLike
|
||||
|
||||
|
||||
class HarmonyMessageEndType(Enum):
|
||||
INCOMPLETE = 0
|
||||
END = 1
|
||||
CALL = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HarmonySequence:
|
||||
"""A data class representing a sequence of tokens in the Harmony format."""
|
||||
token_ids: list[int]
|
||||
start: int # Start position of the sequence in the original token sequence
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HarmonyMessage:
|
||||
"""A data class representing a message in the Harmony format."""
|
||||
end: HarmonyMessageEndType
|
||||
role: HarmonySequence | None = None
|
||||
channel: HarmonySequence | None = None
|
||||
constrain: HarmonySequence | None = None
|
||||
content: HarmonySequence | None = None
|
||||
|
||||
|
||||
class HarmonyMessageParser:
|
||||
"""A parser that performs lexical analysis to extract Harmony messages."""
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
vocab = tokenizer.get_vocab()
|
||||
self._begin_map = {
|
||||
vocab["<|start|>"]: "role",
|
||||
vocab["<|channel|>"]: "channel",
|
||||
vocab["<|constrain|>"]: "constrain",
|
||||
vocab["<|message|>"]: "content",
|
||||
}
|
||||
self._end_map = {
|
||||
vocab["<|end|>"]: HarmonyMessageEndType.END,
|
||||
vocab["<|return|>"]: HarmonyMessageEndType.END,
|
||||
vocab["<|call|>"]: HarmonyMessageEndType.CALL,
|
||||
}
|
||||
|
||||
def iter_messages(self, token_ids: Sequence[int]) -> Iterator[HarmonyMessage]:
|
||||
"""
|
||||
Parse given token ids into messages.
|
||||
|
||||
Args:
|
||||
token_ids: A sequence of token ids to be parsed.
|
||||
|
||||
Yields:
|
||||
Detected HarmonyMessages.
|
||||
"""
|
||||
|
||||
message_dict: dict[str, HarmonySequence] = {}
|
||||
section: str | None = None # None indicates out-of-message.
|
||||
text_ids: list[int] = []
|
||||
text_start: int | None = None
|
||||
|
||||
for token_position, token_id in enumerate(token_ids):
|
||||
if token_id in self._begin_map:
|
||||
if section is not None:
|
||||
message_dict[section] = HarmonySequence(
|
||||
token_ids=text_ids,
|
||||
start=text_start,
|
||||
)
|
||||
section = self._begin_map[token_id]
|
||||
text_ids = []
|
||||
text_start = token_position + 1
|
||||
|
||||
elif token_id in self._end_map:
|
||||
if section is not None:
|
||||
message_dict[section] = HarmonySequence(
|
||||
token_ids=text_ids,
|
||||
start=text_start,
|
||||
)
|
||||
|
||||
yield HarmonyMessage(**message_dict, end=self._end_map[token_id])
|
||||
|
||||
message_dict = {}
|
||||
section = None
|
||||
text_ids = []
|
||||
text_start = None
|
||||
|
||||
else:
|
||||
if section is not None:
|
||||
text_ids.append(token_id)
|
||||
|
||||
if section is not None:
|
||||
message_dict[section] = HarmonySequence(
|
||||
token_ids=text_ids,
|
||||
start=text_start,
|
||||
)
|
||||
yield HarmonyMessage(**message_dict, end=HarmonyMessageEndType.INCOMPLETE)
|
||||
|
||||
def get_all_messages(self, token_ids: Sequence[int]) -> list[HarmonyMessage]:
|
||||
"""
|
||||
Parse given token ids into messages.
|
||||
|
||||
Args:
|
||||
token_ids: A sequence of token ids to be parsed.
|
||||
|
||||
Returns:
|
||||
A list of detected HarmonyMessages.
|
||||
"""
|
||||
return list(self.iter_messages(token_ids))
|
||||
|
||||
def reverse_iter_messages(self, token_ids: Sequence[int]) -> Iterator[HarmonyMessage]:
|
||||
"""
|
||||
Parse given token ids into messages in reverse order.
|
||||
|
||||
Args:
|
||||
token_ids: A sequence of token ids to be parsed.
|
||||
|
||||
Yields:
|
||||
Detected HarmonyMessages in reverse order.
|
||||
"""
|
||||
end_position = len(token_ids)
|
||||
|
||||
for i in range(len(token_ids) - 1, -1, -1):
|
||||
if token_ids[i] == self._start_id:
|
||||
yield next(self.iter_messages(token_ids[i:end_position]))
|
||||
end_position = i
|
||||
Reference in New Issue
Block a user